From 10a77e16a4176d6f015d408dd2d162fc922ebff4 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Wed, 6 Mar 2019 13:26:27 -0500 Subject: Add predefined parameter grids --- test/test_gridsearch.py | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) (limited to 'test/test_gridsearch.py') diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index 16f4a3f..f07e064 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -15,11 +15,7 @@ from sklearn.datasets import load_iris, load_digits from sklearn.model_selection import train_test_split from sklearn.preprocessing import maxabs_scale -from gensvm.gridsearch import ( - GenSVMGridSearchCV, - _validate_param_grid, - load_default_grid, -) +from gensvm.gridsearch import GenSVMGridSearchCV, _validate_param_grid class GenSVMGridSearchCVTestCase(unittest.TestCase): @@ -207,3 +203,39 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): } gg = GenSVMGridSearchCV(pg, verbose=True) gg.fit(Xs, ys) + + def test_gridsearch_tiny(self): + """ GENSVM_GRID: Test with tiny grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="tiny") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.95) + + def test_gridsearch_small(self): + """ GENSVM_GRID: Test with small grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="small") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.95) + + def test_gridsearch_full(self): + """ GENSVM_GRID: Test with full grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="full") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.90) -- cgit v1.2.3