diff options
| -rw-r--r-- | gensvm/core.py | 9 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pxd | 5 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 8 | ||||
| m--------- | src/gensvm | 0 | ||||
| -rw-r--r-- | test/test_core.py | 17 |
5 files changed, 30 insertions, 9 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index f0e7820..bfd5d9a 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -52,7 +52,7 @@ def _fit_gensvm( weight_idx = {"raw": 0, "unit": 1, "group": 2}[weights] # run the actual training - raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap( + raw_coef_, SVs_, n_iter_, training_error_, status_ = wrapper.train_wrap( X, y, n_class, @@ -91,8 +91,9 @@ def _fit_gensvm( coef_ = raw_coef_[1:, :] intercept_ = raw_coef_[0, :] + n_SV_ = sum(SVs_) - return coef_, intercept_, n_iter_, n_SV_ + return coef_, intercept_, n_iter_, n_SV_, SVs_ class GenSVM(BaseEstimator, ClassifierMixin): @@ -177,6 +178,8 @@ class GenSVM(BaseEstimator, ClassifierMixin): n_support_ : int The number of support vectors that were found + SVs_ : array, shape = [n_observations, ] + Index vector that marks the support vectors (1 = SV, 0 = no SV) See Also -------- @@ -351,7 +354,7 @@ class GenSVM(BaseEstimator, ClassifierMixin): ) ) - self.coef_, self.intercept_, self.n_iter_, self.n_support_ = _fit_gensvm( + self.coef_, self.intercept_, self.n_iter_, self.n_support_, self.SVs_ = _fit_gensvm( X, y, n_class, diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index d19d442..0adf7af 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -94,10 +94,6 @@ cdef extern from "gensvm_train.h": void gensvm_train(GenModel *, GenData *, GenModel *) nogil -cdef extern from "gensvm_sv.h": - - long gensvm_num_sv(GenModel *) - cdef extern from "gensvm_queue.h": cdef struct GenQueue: @@ -121,6 +117,7 @@ cdef extern from "gensvm_helper.c": double, double, int, double, double, double, long) char_const_ptr check_model(GenModel *) void copy_V(void *, GenModel *) + void get_SVs(GenModel *, void *) long get_iter_count(GenModel *) double get_training_error(GenModel *) int get_status(GenModel *) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 0858103..009e70b 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -99,18 +99,22 @@ def train_wrap( V = np.empty((n_var+1, n_class-1)) copy_V(V.data, model) + # get the support vectors + cdef np.ndarray[np.int32_t, ndim=1, mode='c'] SVs + SVs = np.empty((n_obs, ), dtype=np.int32) + get_SVs(model, SVs.data) + # get other results from model iter_count = get_iter_count(model) training_error = get_training_error(model) fit_status = get_status(model) - n_SV = gensvm_num_sv(model) # free model and data gensvm_free_model(model); gensvm_free_model(seed_model) free_data(data); - return (V, n_SV, iter_count, training_error, fit_status) + return (V, SVs, iter_count, training_error, fit_status) def predict_wrap( diff --git a/src/gensvm b/src/gensvm -Subproject cc4bf8ef13b90af9ab9169c428850cad80f469b +Subproject 1257d3702aaf9d601eab4dace143550fde4928a diff --git a/test/test_core.py b/test/test_core.py index 63c172d..ebea89b 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -322,6 +322,23 @@ class GenSVMTestCase(unittest.TestCase): clf = GenSVM(kernel="rbf") clf.fit(X, y, sample_weight=weights) + def test_support_vectors(self): + """ GENSVM: Test getting the support vectors """ + X = np.array( + [[-2, 2], [-2, -2], [-1, 0], [1, 0], [2, 2], [2, -2]] # SV # SV + ) + y = np.array([0, 0, 0, 1, 1, 1]) + clf = GenSVM() + clf.fit(X, y) + SVs = clf.SVs_ + self.assertEqual(SVs[0], 0) + self.assertEqual(SVs[1], 0) + self.assertEqual(SVs[2], 1) + self.assertEqual(SVs[3], 1) + self.assertEqual(SVs[4], 0) + self.assertEqual(SVs[5], 0) + self.assertEqual(clf.n_support_, 2) + if __name__ == "__main__": unittest.main() |
