diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-05-30 17:31:17 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-05-30 17:31:17 +0100 |
| commit | 9ad95f7075cf6e03b27207f004fce36909ebbb53 (patch) | |
| tree | 27b6dd7425a19c1cc2d99dc046f46cb9ec55863c /gensvm/cython_wrapper | |
| parent | use green for testing instead of nose (diff) | |
| download | pygensvm-9ad95f7075cf6e03b27207f004fce36909ebbb53.tar.gz pygensvm-9ad95f7075cf6e03b27207f004fce36909ebbb53.zip | |
Add support for returning the support vectors (#5)
Diffstat (limited to 'gensvm/cython_wrapper')
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pxd | 5 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 8 |
2 files changed, 7 insertions, 6 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index d19d442..0adf7af 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -94,10 +94,6 @@ cdef extern from "gensvm_train.h": void gensvm_train(GenModel *, GenData *, GenModel *) nogil -cdef extern from "gensvm_sv.h": - - long gensvm_num_sv(GenModel *) - cdef extern from "gensvm_queue.h": cdef struct GenQueue: @@ -121,6 +117,7 @@ cdef extern from "gensvm_helper.c": double, double, int, double, double, double, long) char_const_ptr check_model(GenModel *) void copy_V(void *, GenModel *) + void get_SVs(GenModel *, void *) long get_iter_count(GenModel *) double get_training_error(GenModel *) int get_status(GenModel *) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 0858103..009e70b 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -99,18 +99,22 @@ def train_wrap( V = np.empty((n_var+1, n_class-1)) copy_V(V.data, model) + # get the support vectors + cdef np.ndarray[np.int32_t, ndim=1, mode='c'] SVs + SVs = np.empty((n_obs, ), dtype=np.int32) + get_SVs(model, SVs.data) + # get other results from model iter_count = get_iter_count(model) training_error = get_training_error(model) fit_status = get_status(model) - n_SV = gensvm_num_sv(model) # free model and data gensvm_free_model(model); gensvm_free_model(seed_model) free_data(data); - return (V, n_SV, iter_count, training_error, fit_status) + return (V, SVs, iter_count, training_error, fit_status) def predict_wrap( |
