diff options
Diffstat (limited to 'gensvm/cython_wrapper/wrapper.pyx')
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 6d5bbae..e138620 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -133,6 +133,45 @@ def predict_wrap( return predictions + +def predict_kernels_wrap( + np.ndarray[np.float64_t, ndim=2, mode='c'] Xtest, + np.ndarray[np.float64_t, ndim=2, mode='c'] Xtrain, + np.ndarray[np.float64_t, ndim=2, mode='c'] V, + long n_class, + int kernel_idx, + double gamma, + double coef, + double degree, + double kernel_eigen_cutoff + ): + """ + Compute predictions for nonlinear GenSVM. Calls the C helper function + "gensvm_predict_kernels", which in turn calls the appropriate library + functions. + """ + + cdef long n_obs_test + cdef long n_obs_train + cdef long n_var + cdef long V_rows = V.shape[0] + cdef long V_cols = V.shape[1] + + n_obs_test = Xtest.shape[0] + n_obs_train = Xtrain.shape[0] + n_var = Xtrain.shape[1] + + cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions + predictions = np.empty((n_obs_test, ), dtype=np.int) + + with nogil: + gensvm_predict_kernels(Xtest.data, Xtrain.data, V.data, V_rows, + V_cols, n_obs_train, n_obs_test, n_var, n_class, kernel_idx, + gamma, coef, degree, kernel_eigen_cutoff, predictions.data) + + return predictions + + def grid_wrap( np.ndarray[np.float64_t, ndim=2, mode='c'] X, np.ndarray[np.int_t, ndim=1, mode='c'] y, |
