aboutsummaryrefslogtreecommitdiff
path: root/gensvm
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-05-30 17:34:05 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-05-30 17:34:05 +0100
commit06d9236418a01157056c20c416c8e25cbc6eaf73 (patch)
tree6539d73ccb6b3708110c78158b7a1cfd551020fd /gensvm
parentbump version (diff)
parentAdd support for returning the support vectors (#5) (diff)
downloadpygensvm-06d9236418a01157056c20c416c8e25cbc6eaf73.tar.gz
pygensvm-06d9236418a01157056c20c416c8e25cbc6eaf73.zip
Merge branch 'master' of github.com:GjjvdBurg/PyGenSVM
Diffstat (limited to 'gensvm')
-rw-r--r--gensvm/core.py9
-rw-r--r--gensvm/cython_wrapper/wrapper.pxd5
-rw-r--r--gensvm/cython_wrapper/wrapper.pyx8
3 files changed, 13 insertions, 9 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index f0e7820..bfd5d9a 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -52,7 +52,7 @@ def _fit_gensvm(
weight_idx = {"raw": 0, "unit": 1, "group": 2}[weights]
# run the actual training
- raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap(
+ raw_coef_, SVs_, n_iter_, training_error_, status_ = wrapper.train_wrap(
X,
y,
n_class,
@@ -91,8 +91,9 @@ def _fit_gensvm(
coef_ = raw_coef_[1:, :]
intercept_ = raw_coef_[0, :]
+ n_SV_ = sum(SVs_)
- return coef_, intercept_, n_iter_, n_SV_
+ return coef_, intercept_, n_iter_, n_SV_, SVs_
class GenSVM(BaseEstimator, ClassifierMixin):
@@ -177,6 +178,8 @@ class GenSVM(BaseEstimator, ClassifierMixin):
n_support_ : int
The number of support vectors that were found
+ SVs_ : array, shape = [n_observations, ]
+ Index vector that marks the support vectors (1 = SV, 0 = no SV)
See Also
--------
@@ -351,7 +354,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
)
)
- self.coef_, self.intercept_, self.n_iter_, self.n_support_ = _fit_gensvm(
+ self.coef_, self.intercept_, self.n_iter_, self.n_support_, self.SVs_ = _fit_gensvm(
X,
y,
n_class,
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd
index d19d442..0adf7af 100644
--- a/gensvm/cython_wrapper/wrapper.pxd
+++ b/gensvm/cython_wrapper/wrapper.pxd
@@ -94,10 +94,6 @@ cdef extern from "gensvm_train.h":
void gensvm_train(GenModel *, GenData *, GenModel *) nogil
-cdef extern from "gensvm_sv.h":
-
- long gensvm_num_sv(GenModel *)
-
cdef extern from "gensvm_queue.h":
cdef struct GenQueue:
@@ -121,6 +117,7 @@ cdef extern from "gensvm_helper.c":
double, double, int, double, double, double, long)
char_const_ptr check_model(GenModel *)
void copy_V(void *, GenModel *)
+ void get_SVs(GenModel *, void *)
long get_iter_count(GenModel *)
double get_training_error(GenModel *)
int get_status(GenModel *)
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index 0858103..009e70b 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -99,18 +99,22 @@ def train_wrap(
V = np.empty((n_var+1, n_class-1))
copy_V(V.data, model)
+ # get the support vectors
+ cdef np.ndarray[np.int32_t, ndim=1, mode='c'] SVs
+ SVs = np.empty((n_obs, ), dtype=np.int32)
+ get_SVs(model, SVs.data)
+
# get other results from model
iter_count = get_iter_count(model)
training_error = get_training_error(model)
fit_status = get_status(model)
- n_SV = gensvm_num_sv(model)
# free model and data
gensvm_free_model(model);
gensvm_free_model(seed_model)
free_data(data);
- return (V, n_SV, iter_count, training_error, fit_status)
+ return (V, SVs, iter_count, training_error, fit_status)
def predict_wrap(