diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 13:27:44 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 13:27:44 -0500 |
| commit | 4328e3f2e02afb98ab5357a40b201cf70413f29f (patch) | |
| tree | 61b40c5e70da8ce781b8dcd51d58e68afed99132 /test | |
| parent | Simplify citation so it's easier to use correctly (diff) | |
| parent | Add predefined parameter grids (diff) | |
| download | pygensvm-4328e3f2e02afb98ab5357a40b201cf70413f29f.tar.gz pygensvm-4328e3f2e02afb98ab5357a40b201cf70413f29f.zip | |
Merge branch 'param_grids'
Diffstat (limited to 'test')
| -rw-r--r-- | test/test_gridsearch.py | 42 |
1 files changed, 37 insertions, 5 deletions
diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index 16f4a3f..f07e064 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -15,11 +15,7 @@ from sklearn.datasets import load_iris, load_digits from sklearn.model_selection import train_test_split from sklearn.preprocessing import maxabs_scale -from gensvm.gridsearch import ( - GenSVMGridSearchCV, - _validate_param_grid, - load_default_grid, -) +from gensvm.gridsearch import GenSVMGridSearchCV, _validate_param_grid class GenSVMGridSearchCVTestCase(unittest.TestCase): @@ -207,3 +203,39 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): } gg = GenSVMGridSearchCV(pg, verbose=True) gg.fit(Xs, ys) + + def test_gridsearch_tiny(self): + """ GENSVM_GRID: Test with tiny grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="tiny") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.95) + + def test_gridsearch_small(self): + """ GENSVM_GRID: Test with small grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="small") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.95) + + def test_gridsearch_full(self): + """ GENSVM_GRID: Test with full grid """ + X, y = load_iris(return_X_y=True) + X = maxabs_scale(X) + X_train, X_test, y_train, y_test = train_test_split(X, y) + + clf = GenSVMGridSearchCV(param_grid="full") + clf.fit(X_train, y_train) + + score = clf.score(X_test, y_test) + self.assertGreaterEqual(score, 0.90) |
