aboutsummaryrefslogtreecommitdiff
path: root/gensvm
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-01-15 14:52:55 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-01-15 14:52:55 +0000
commitf695999cb24a91b39b14e48c03b28aec78bb95c1 (patch)
tree457c3026659ecaab11e7e92b02c4f0d559daa0f0 /gensvm
parentDocument the random_state argument (diff)
downloadpygensvm-f695999cb24a91b39b14e48c03b28aec78bb95c1.tar.gz
pygensvm-f695999cb24a91b39b14e48c03b28aec78bb95c1.zip
Add code for prediction with kernels
Diffstat (limited to 'gensvm')
-rw-r--r--gensvm/core.py36
-rw-r--r--gensvm/cython_wrapper/wrapper.pxd2
-rw-r--r--gensvm/cython_wrapper/wrapper.pyx39
3 files changed, 74 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index a2c4cd0..ce416d0 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -343,20 +343,50 @@ class GenSVM(BaseEstimator, ClassifierMixin):
)
return self
- def predict(self, X):
+ def predict(self, X, trainX=None):
"""Predict the class labels on the given data
Parameters
----------
- X : array, shape = [n_samples, n_features]
+ X : array, shape = [n_test_samples, n_features]
+ Data for which to predict the labels
+
+ trainX : array, shape = [n_train_samples, n_features]
+ Only for nonlinear prediction with kernels: the training data used
+ to train the model.
Returns
-------
y_pred : array, shape = (n_samples, )
"""
+
+ if (not self.kernel == "linear") and trainX is None:
+ raise ValueError(
+ "Training data must be provided with nonlinear prediction"
+ )
+ if not trainX is None and not X.shape[1] == trainX.shape[1]:
+ raise ValueError(
+ "Test and training data should have the same number of features"
+ )
+
V = self.combined_coef_
- predictions = wrapper.predict_wrap(X, V)
+ if self.kernel == "linear":
+ predictions = wrapper.predict_wrap(X, V)
+ else:
+ n_class = len(self.encoder.classes_)
+ kernel_idx = wrapper.GENSVM_KERNEL_TYPES.index(self.kernel)
+ predictions = wrapper.predict_kernels_wrap(
+ X,
+ trainX,
+ V,
+ n_class,
+ kernel_idx,
+ self.gamma,
+ self.coef,
+ self.degree,
+ self.kernel_eigen_cutoff,
+ )
# Transform the classes back to the original form
predictions -= 1
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd
index 441c15b..6a896aa 100644
--- a/gensvm/cython_wrapper/wrapper.pxd
+++ b/gensvm/cython_wrapper/wrapper.pxd
@@ -129,6 +129,8 @@ cdef extern from "gensvm_helper.c":
void free_data(GenData *)
void set_verbosity(int)
void gensvm_predict(char *, char *, long, long, long, char *) nogil
+ void gensvm_predict_kernels(char *, char *, char *, long, long, long,
+ long, long, long, int, double, double, double, double, char *) nogil
void gensvm_train_q_helper(GenQueue *, char *, int) nogil
void set_queue(GenQueue *, long, GenTask **)
double get_task_duration(GenTask *)
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index 6d5bbae..e138620 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -133,6 +133,45 @@ def predict_wrap(
return predictions
+
+def predict_kernels_wrap(
+ np.ndarray[np.float64_t, ndim=2, mode='c'] Xtest,
+ np.ndarray[np.float64_t, ndim=2, mode='c'] Xtrain,
+ np.ndarray[np.float64_t, ndim=2, mode='c'] V,
+ long n_class,
+ int kernel_idx,
+ double gamma,
+ double coef,
+ double degree,
+ double kernel_eigen_cutoff
+ ):
+ """
+ Compute predictions for nonlinear GenSVM. Calls the C helper function
+ "gensvm_predict_kernels", which in turn calls the appropriate library
+ functions.
+ """
+
+ cdef long n_obs_test
+ cdef long n_obs_train
+ cdef long n_var
+ cdef long V_rows = V.shape[0]
+ cdef long V_cols = V.shape[1]
+
+ n_obs_test = Xtest.shape[0]
+ n_obs_train = Xtrain.shape[0]
+ n_var = Xtrain.shape[1]
+
+ cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions
+ predictions = np.empty((n_obs_test, ), dtype=np.int)
+
+ with nogil:
+ gensvm_predict_kernels(Xtest.data, Xtrain.data, V.data, V_rows,
+ V_cols, n_obs_train, n_obs_test, n_var, n_class, kernel_idx,
+ gamma, coef, degree, kernel_eigen_cutoff, predictions.data)
+
+ return predictions
+
+
def grid_wrap(
np.ndarray[np.float64_t, ndim=2, mode='c'] X,
np.ndarray[np.int_t, ndim=1, mode='c'] y,