diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 22:28:49 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 22:28:49 -0500 |
| commit | 3ed51bb4ac5db6dfec92d10d6e302381c27849c8 (patch) | |
| tree | 24d1df8d8ef4bdf003ca35ae310a5804c28b455e /test | |
| parent | Bugfix and test for predict method of GridSearch (diff) | |
| download | pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.tar.gz pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.zip | |
Add support for specifying sample weights (fixes #2)
Diffstat (limited to 'test')
| -rw-r--r-- | test/test_core.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test_core.py b/test/test_core.py index dcc72bf..af4bb2a 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -291,6 +291,20 @@ class GenSVMTestCase(unittest.TestCase): self.assertLess(abs(V[3, 1] - 0.6547924025329720), eps) self.assertLess(abs(V[3, 2] - -0.6773346708737853), eps) + def test_fit_with_weights(self): + """ GENSVM: Test fit with sample weights """ + X, y = load_iris(return_X_y=True) + weights = np.random.random((X.shape[0],)) + clf = GenSVM() + clf.fit(X, y, sample_weight=weights) + # with seeding + V = clf.combined_coef_ + weights = np.random.random((X.shape[0],)) + clf.fit(X, y, sample_weight=weights, seed_V=V) + + clf = GenSVM(kernel="rbf") + clf.fit(X, y, sample_weight=weights) + if __name__ == "__main__": unittest.main() |
