aboutsummaryrefslogtreecommitdiff
path: root/gensvm/cython_wrapper
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-05-30 17:31:17 +0100
committerGitHub <noreply@github.com>2019-05-30 17:31:17 +0100
commit9ad95f7075cf6e03b27207f004fce36909ebbb53 (patch)
tree27b6dd7425a19c1cc2d99dc046f46cb9ec55863c /gensvm/cython_wrapper
parentuse green for testing instead of nose (diff)
downloadpygensvm-9ad95f7075cf6e03b27207f004fce36909ebbb53.tar.gz
pygensvm-9ad95f7075cf6e03b27207f004fce36909ebbb53.zip
Add support for returning the support vectors (#5)
Diffstat (limited to 'gensvm/cython_wrapper')
-rw-r--r--gensvm/cython_wrapper/wrapper.pxd5
-rw-r--r--gensvm/cython_wrapper/wrapper.pyx8
2 files changed, 7 insertions, 6 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd
index d19d442..0adf7af 100644
--- a/gensvm/cython_wrapper/wrapper.pxd
+++ b/gensvm/cython_wrapper/wrapper.pxd
@@ -94,10 +94,6 @@ cdef extern from "gensvm_train.h":
void gensvm_train(GenModel *, GenData *, GenModel *) nogil
-cdef extern from "gensvm_sv.h":
-
- long gensvm_num_sv(GenModel *)
-
cdef extern from "gensvm_queue.h":
cdef struct GenQueue:
@@ -121,6 +117,7 @@ cdef extern from "gensvm_helper.c":
double, double, int, double, double, double, long)
char_const_ptr check_model(GenModel *)
void copy_V(void *, GenModel *)
+ void get_SVs(GenModel *, void *)
long get_iter_count(GenModel *)
double get_training_error(GenModel *)
int get_status(GenModel *)
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index 0858103..009e70b 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -99,18 +99,22 @@ def train_wrap(
V = np.empty((n_var+1, n_class-1))
copy_V(V.data, model)
+ # get the support vectors
+ cdef np.ndarray[np.int32_t, ndim=1, mode='c'] SVs
+ SVs = np.empty((n_obs, ), dtype=np.int32)
+ get_SVs(model, SVs.data)
+
# get other results from model
iter_count = get_iter_count(model)
training_error = get_training_error(model)
fit_status = get_status(model)
- n_SV = gensvm_num_sv(model)
# free model and data
gensvm_free_model(model);
gensvm_free_model(seed_model)
free_data(data);
- return (V, n_SV, iter_count, training_error, fit_status)
+ return (V, SVs, iter_count, training_error, fit_status)
def predict_wrap(