aboutsummaryrefslogtreecommitdiff
path: root/gensvm/cython_wrapper/wrapper.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/cython_wrapper/wrapper.pyx')
-rw-r--r--gensvm/cython_wrapper/wrapper.pyx8
1 files changed, 6 insertions, 2 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index 0858103..009e70b 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -99,18 +99,22 @@ def train_wrap(
V = np.empty((n_var+1, n_class-1))
copy_V(V.data, model)
+ # get the support vectors
+ cdef np.ndarray[np.int32_t, ndim=1, mode='c'] SVs
+ SVs = np.empty((n_obs, ), dtype=np.int32)
+ get_SVs(model, SVs.data)
+
# get other results from model
iter_count = get_iter_count(model)
training_error = get_training_error(model)
fit_status = get_status(model)
- n_SV = gensvm_num_sv(model)
# free model and data
gensvm_free_model(model);
gensvm_free_model(seed_model)
free_data(data);
- return (V, n_SV, iter_count, training_error, fit_status)
+ return (V, SVs, iter_count, training_error, fit_status)
def predict_wrap(