Skip to content

Commit

Permalink
redo grid_search arg validation and a couple test fixes (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
blbarker authored Jan 9, 2017
1 parent 825ab29 commit ff560fe
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion integration-tests/tests/test_random_forest_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_random_forest_classifier(tc):
predict_frame = model.predict(f)
assert(set(predict_frame.column_names) == set(['Class', 'Dim_1', 'Dim_2','predicted_class']))
assert(len(predict_frame.column_names) == 4)
metrics = model.test(f, 'Class')
metrics = model.test(f)
assert(metrics.accuracy == 1.0)
assert(metrics.f_measure == 1.0)
assert(metrics.precision == 1.0)
Expand Down
2 changes: 1 addition & 1 deletion integration-tests/tests/test_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_svm(tc):
assert(set(predicted_frame.column_names) == set(['data', 'label', 'predicted_label']))
assert(len(predicted_frame.column_names) == 3)
assert(len(f.column_names) == 2)
metrics = model.test(predicted_frame, 'label')
metrics = model.test(predicted_frame)
assert(metrics.accuracy == 1.0)
assert(metrics.f_measure == 1.0)
assert(metrics.precision == 1.0)
Expand Down
28 changes: 17 additions & 11 deletions python/sparktk/models/_selection/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sparktk import TkContext
from sparktk.frame.ops.classification_metrics_value import ClassificationMetricsValue
from collections import namedtuple
from sparktk.arguments import extract_call, validate_call
from sparktk.arguments import extract_call, validate_call, require_type, affirm_type, value_error
from sparktk.frame.frame import Frame
from sparktk import arguments

Expand Down Expand Up @@ -156,11 +156,18 @@ def grid_search(train_frame, test_frame, train_descriptors, tc= TkContext.implic
"""

#validate input
# validate input
TkContext.validate(tc)
if not isinstance(train_descriptors, list):
train_descriptors = [train_descriptors]
descriptors = [TrainDescriptor(x[0], x[1]) for x in train_descriptors if not isinstance(x, TrainDescriptor)]
descriptors = affirm_type.list_of_anything(train_descriptors, "train_descriptors")
for i in xrange(len(descriptors)):
item = descriptors[i]
if not isinstance(item, TrainDescriptor):
require_type(tuple, item, "item", "grid_search needs a list of items which are either of type TrainDescriptor or tuples of (model, train_kwargs)")
if len(item) != 2:
raise value_error("list requires tuples of len 2", item, "item in train_descriptors")
if not hasattr(item[0], 'train'):
raise value_error("first item in tuple needs to be a object with a 'train' function", item, "item in train_descriptors")
descriptors[i] = TrainDescriptor(item[0], item[1])

arguments.require_type(Frame, train_frame, "frame")
arguments.require_type(Frame, test_frame, "frame")
Expand Down Expand Up @@ -258,15 +265,14 @@ def _create_metric_sum(a, b):


class TrainDescriptor(object):
"""
Class that separates the model type and args from the input and handles the representation.
"""
"""Describes a train operation: a model type and the arguments for its train method"""

def __init__(self, model_type, kwargs):
"""
Initializes the model_type and model's arguments
:param model_type: The name of the model
:param kwargs: The list of model parameters
Creates a TrainDescriptor
:param model_type: type object representing the model in question
:param kwargs: dict of key-value-pairs holding values for the train method's parameters
"""
self.model_type = model_type
self.kwargs = kwargs
Expand Down

0 comments on commit ff560fe

Please sign in to comment.