diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 13:26:27 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 13:26:27 -0500 |
| commit | 10a77e16a4176d6f015d408dd2d162fc922ebff4 (patch) | |
| tree | 6f3a6be596fbeed6bf80537827197344a52d9e9e /test | |
| parent | update submodule (diff) | |
| download | pygensvm-10a77e16a4176d6f015d408dd2d162fc922ebff4.tar.gz pygensvm-10a77e16a4176d6f015d408dd2d162fc922ebff4.zip | |
Add predefined parameter grids
Diffstat (limited to 'test')
| -rw-r--r-- | test/test_gridsearch.py | 42 |
1 files changed, 37 insertions, 5 deletions
diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index 16f4a3f..f07e064 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -15,11 +15,7 @@ from sklearn.datasets import load_iris, load_digits from sklearn.model_selection import train_test_split from sklearn.preprocessing import maxabs_scale -from gensvm.gridsearch import ( - GenSVMGridSearchCV, - _validate_param_grid, - load_default_grid, -) +from gensvm.gridsearch import GenSVMGridSearchCV, _validate_param_grid class GenSVMGridSearchCVTestCase(unittest.TestCase): @@ -207,3 +203,39 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): } gg = GenSVMGridSearchCV(pg, verbose=True) gg.fit(Xs, ys) + + def test_gridsearch_tiny(self): + """ GENSVM_GRID: Test with tiny grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="tiny") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.95) + + def test_gridsearch_small(self): + """ GENSVM_GRID: Test with small grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="small") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.95) + + def test_gridsearch_full(self): + """ GENSVM_GRID: Test with full grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="full") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.90) |
