aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_core.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/test/test_core.py b/test/test_core.py
index 63c172d..ebea89b 100644
--- a/test/test_core.py
+++ b/test/test_core.py
@@ -322,6 +322,23 @@ class GenSVMTestCase(unittest.TestCase):
clf = GenSVM(kernel="rbf")
clf.fit(X, y, sample_weight=weights)
+ def test_support_vectors(self):
+ """ GENSVM: Test getting the support vectors """
+ X = np.array(
+ [[-2, 2], [-2, -2], [-1, 0], [1, 0], [2, 2], [2, -2]] # SV # SV
+ )
+ y = np.array([0, 0, 0, 1, 1, 1])
+ clf = GenSVM()
+ clf.fit(X, y)
+ SVs = clf.SVs_
+ self.assertEqual(SVs[0], 0)
+ self.assertEqual(SVs[1], 0)
+ self.assertEqual(SVs[2], 1)
+ self.assertEqual(SVs[3], 1)
+ self.assertEqual(SVs[4], 0)
+ self.assertEqual(SVs[5], 0)
+ self.assertEqual(clf.n_support_, 2)
+
if __name__ == "__main__":
unittest.main()