aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 22:28:49 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 22:28:49 -0500
commit3ed51bb4ac5db6dfec92d10d6e302381c27849c8 (patch)
tree24d1df8d8ef4bdf003ca35ae310a5804c28b455e /test
parentBugfix and test for predict method of GridSearch (diff)
downloadpygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.tar.gz
pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.zip
Add support for specifying sample weights (fixes #2)
Diffstat (limited to 'test')
-rw-r--r--test/test_core.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test_core.py b/test/test_core.py
index dcc72bf..af4bb2a 100644
--- a/test/test_core.py
+++ b/test/test_core.py
@@ -291,6 +291,20 @@ class GenSVMTestCase(unittest.TestCase):
self.assertLess(abs(V[3, 1] - 0.6547924025329720), eps)
self.assertLess(abs(V[3, 2] - -0.6773346708737853), eps)
+ def test_fit_with_weights(self):
+ """ GENSVM: Test fit with sample weights """
+ X, y = load_iris(return_X_y=True)
+ weights = np.random.random((X.shape[0],))
+ clf = GenSVM()
+ clf.fit(X, y, sample_weight=weights)
+ # with seeding
+ V = clf.combined_coef_
+ weights = np.random.random((X.shape[0],))
+ clf.fit(X, y, sample_weight=weights, seed_V=V)
+
+ clf = GenSVM(kernel="rbf")
+ clf.fit(X, y, sample_weight=weights)
+
if __name__ == "__main__":
unittest.main()