aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 13:26:27 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 13:26:27 -0500
commit10a77e16a4176d6f015d408dd2d162fc922ebff4 (patch)
tree6f3a6be596fbeed6bf80537827197344a52d9e9e /test
parentupdate submodule (diff)
downloadpygensvm-10a77e16a4176d6f015d408dd2d162fc922ebff4.tar.gz
pygensvm-10a77e16a4176d6f015d408dd2d162fc922ebff4.zip
Add predefined parameter grids
Diffstat (limited to 'test')
-rw-r--r--test/test_gridsearch.py42
1 files changed, 37 insertions, 5 deletions
diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py
index 16f4a3f..f07e064 100644
--- a/test/test_gridsearch.py
+++ b/test/test_gridsearch.py
@@ -15,11 +15,7 @@ from sklearn.datasets import load_iris, load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import maxabs_scale
-from gensvm.gridsearch import (
- GenSVMGridSearchCV,
- _validate_param_grid,
- load_default_grid,
-)
+from gensvm.gridsearch import GenSVMGridSearchCV, _validate_param_grid
class GenSVMGridSearchCVTestCase(unittest.TestCase):
@@ -207,3 +203,39 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase):
}
gg = GenSVMGridSearchCV(pg, verbose=True)
gg.fit(Xs, ys)
+
+ def test_gridsearch_tiny(self):
+ """ GENSVM_GRID: Test with tiny grid """
+ X, y = load_iris(return_X_y=True)
+ X = maxabs_scale(X)
+ X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+ clf = GenSVMGridSearchCV(param_grid="tiny")
+ clf.fit(X_train, y_train)
+
+ score = clf.score(X_test, y_test)
+ self.assertGreaterEqual(score, 0.95)
+
+ def test_gridsearch_small(self):
+ """ GENSVM_GRID: Test with small grid """
+ X, y = load_iris(return_X_y=True)
+ X = maxabs_scale(X)
+ X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+ clf = GenSVMGridSearchCV(param_grid="small")
+ clf.fit(X_train, y_train)
+
+ score = clf.score(X_test, y_test)
+ self.assertGreaterEqual(score, 0.95)
+
+ def test_gridsearch_full(self):
+ """ GENSVM_GRID: Test with full grid """
+ X, y = load_iris(return_X_y=True)
+ X = maxabs_scale(X)
+ X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+ clf = GenSVMGridSearchCV(param_grid="full")
+ clf.fit(X_train, y_train)
+
+ score = clf.score(X_test, y_test)
+ self.assertGreaterEqual(score, 0.90)