diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 12:24:52 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 12:24:52 -0500 |
| commit | c5515107b8e3ead4724096d5c381608c957cc4df (patch) | |
| tree | 171f55936cea662e6302949dff2d0874708c1c89 /test | |
| parent | Formatting (diff) | |
| download | pygensvm-c5515107b8e3ead4724096d5c381608c957cc4df.tar.gz pygensvm-c5515107b8e3ead4724096d5c381608c957cc4df.zip | |
Add test for nonlinear training
Diffstat (limited to 'test')
| -rw-r--r-- | test/test_core.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_core.py b/test/test_core.py index 4be7df4..68d4b1e 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -180,6 +180,22 @@ class GenSVMTestCase(unittest.TestCase): label_set = set(labels) self.assertTrue(pred_set.issubset(label_set)) + def test_fit_nonlinear(self): + """ GENSVM: Fit and predict with nonlinear kernel """ + data = load_iris() + X = data.data + y = data.target_names[data.target] + + X_train, X_test, y_train, y_test = train_test_split( + X, y, random_state=123 + ) + clf = GenSVM(kernel="rbf", gamma=10, max_iter=5000, random_state=123) + clf.fit(X_train, y_train) + + pred = clf.predict(X_test, trainX=X_train) + self.assertTrue(set(pred).issubset(set(['versicolor', 'virginica', + 'setosa']))) + def test_fit_with_seed(self): """ GENSVM: Test fit with seeding """ |
