aboutsummaryrefslogtreecommitdiff
path: root/gensvm/cython_wrapper/wrapper.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/cython_wrapper/wrapper.pyx')
-rw-r--r--gensvm/cython_wrapper/wrapper.pyx39
1 files changed, 39 insertions, 0 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index 6d5bbae..e138620 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -133,6 +133,45 @@ def predict_wrap(
return predictions
+
+def predict_kernels_wrap(
+ np.ndarray[np.float64_t, ndim=2, mode='c'] Xtest,
+ np.ndarray[np.float64_t, ndim=2, mode='c'] Xtrain,
+ np.ndarray[np.float64_t, ndim=2, mode='c'] V,
+ long n_class,
+ int kernel_idx,
+ double gamma,
+ double coef,
+ double degree,
+ double kernel_eigen_cutoff
+ ):
+ """
+ Compute predictions for nonlinear GenSVM. Calls the C helper function
+ "gensvm_predict_kernels", which in turn calls the appropriate library
+ functions.
+ """
+
+ cdef long n_obs_test
+ cdef long n_obs_train
+ cdef long n_var
+ cdef long V_rows = V.shape[0]
+ cdef long V_cols = V.shape[1]
+
+ n_obs_test = Xtest.shape[0]
+ n_obs_train = Xtrain.shape[0]
+ n_var = Xtrain.shape[1]
+
+ cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions
+ predictions = np.empty((n_obs_test, ), dtype=np.int)
+
+ with nogil:
+ gensvm_predict_kernels(Xtest.data, Xtrain.data, V.data, V_rows,
+ V_cols, n_obs_train, n_obs_test, n_var, n_class, kernel_idx,
+ gamma, coef, degree, kernel_eigen_cutoff, predictions.data)
+
+ return predictions
+
+
def grid_wrap(
np.ndarray[np.float64_t, ndim=2, mode='c'] X,
np.ndarray[np.int_t, ndim=1, mode='c'] y,