aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 12:24:52 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 12:24:52 -0500
commitc5515107b8e3ead4724096d5c381608c957cc4df (patch)
tree171f55936cea662e6302949dff2d0874708c1c89 /test
parentFormatting (diff)
downloadpygensvm-c5515107b8e3ead4724096d5c381608c957cc4df.tar.gz
pygensvm-c5515107b8e3ead4724096d5c381608c957cc4df.zip
Add test for nonlinear training
Diffstat (limited to 'test')
-rw-r--r--test/test_core.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_core.py b/test/test_core.py
index 4be7df4..68d4b1e 100644
--- a/test/test_core.py
+++ b/test/test_core.py
@@ -180,6 +180,22 @@ class GenSVMTestCase(unittest.TestCase):
label_set = set(labels)
self.assertTrue(pred_set.issubset(label_set))
+ def test_fit_nonlinear(self):
+ """ GENSVM: Fit and predict with nonlinear kernel """
+ data = load_iris()
+ X = data.data
+ y = data.target_names[data.target]
+
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, random_state=123
+ )
+ clf = GenSVM(kernel="rbf", gamma=10, max_iter=5000, random_state=123)
+ clf.fit(X_train, y_train)
+
+ pred = clf.predict(X_test, trainX=X_train)
+ self.assertTrue(set(pred).issubset(set(['versicolor', 'virginica',
+ 'setosa'])))
+
def test_fit_with_seed(self):
""" GENSVM: Test fit with seeding """