#!/usr/bin/env python
# coding=utf-8
"""Database specifications for an evaluation protocol based on the Iris Flower
databases from Fisher's original work."""
import numpy
# We don't know where the file data.csv will be installed
# on the user filesystem, we need to ask package management
# where it is and load it from there. We do this here.
import pkg_resources
DATAFILE = pkg_resources.resource_filename(__name__, "data.csv")
# A list of protocols we implement
PROTOCOLS = {
"proto1": {"train": range(0, 30), "test": range(30, 50)},
"proto2": {"train": range(20, 50), "test": range(0, 20)},
}
# Subsets of the database in each protocol
SUBSETS = [
"train",
"test",
]
# The types of Iris flowers in the dataset
CLASSES = [
"setosa",
"versicolor",
"virginica",
]
# The four values that were sampled
VARIABLES = [
"sepal length",
"sepal width",
"petal length",
"petal width",
]
[docs]def load():
"""Loads the data from its CSV format into an easy to dictionary of arrays"""
import csv
data = dict([(k, []) for k in CLASSES])
with open(DATAFILE, "rt") as f:
reader = csv.reader(f)
for k, row in enumerate(reader):
if not k:
continue
data[row[4]].append(numpy.array([float(z) for z in row[:4]]))
for k in CLASSES:
data[k] = numpy.vstack(data[k])
return data
[docs]def split_data(data, subset, splits):
"""Returns the data for a given protocol"""
return dict([(k, data[k][splits[subset]]) for k in data])
[docs]def get(protocol, subset, classes=CLASSES, variables=VARIABLES):
"""Returns the data subset given a particular protocol
Parameters
==========
protocol : str
one of the valid protocols supported by this interface
subset : str
one of 'train' or 'test'
classes : :py:class:`list` of :py:class:`str`
a list of strings containing the names of the classes from which you
want to have the data from
variables : :py:class:`list` of :py:class:`str`
a list of strings containg the names of the variables (features) you
want to have data from
Returns
=======
data : numpy.ndarray
The data for all the classes and variables nicely packed into one numpy
3D array. One depth represents the data for one class, one row is one
example, one column a given feature.
"""
retval = split_data(load(), subset, PROTOCOLS[protocol])
# filter variables (features)
varindex = [VARIABLES.index(k) for k in variables]
# filter class names and variable indexes at the same time
retval = dict([(k, retval[k][:, varindex]) for k in classes])
# squash the data
return numpy.array([retval[k] for k in classes])