diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-01-15 14:52:55 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-01-15 14:52:55 +0000 |
| commit | f695999cb24a91b39b14e48c03b28aec78bb95c1 (patch) | |
| tree | 457c3026659ecaab11e7e92b02c4f0d559daa0f0 /gensvm | |
| parent | Document the random_state argument (diff) | |
| download | pygensvm-f695999cb24a91b39b14e48c03b28aec78bb95c1.tar.gz pygensvm-f695999cb24a91b39b14e48c03b28aec78bb95c1.zip | |
Add code for prediction with kernels
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/core.py | 36 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pxd | 2 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 39 |
3 files changed, 74 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index a2c4cd0..ce416d0 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -343,20 +343,50 @@ class GenSVM(BaseEstimator, ClassifierMixin): ) return self - def predict(self, X): + def predict(self, X, trainX=None): """Predict the class labels on the given data Parameters ---------- - X : array, shape = [n_samples, n_features] + X : array, shape = [n_test_samples, n_features] + Data for which to predict the labels + + trainX : array, shape = [n_train_samples, n_features] + Only for nonlinear prediction with kernels: the training data used + to train the model. Returns ------- y_pred : array, shape = (n_samples, ) """ + + if (not self.kernel == "linear") and trainX is None: + raise ValueError( + "Training data must be provided with nonlinear prediction" + ) + if not trainX is None and not X.shape[1] == trainX.shape[1]: + raise ValueError( + "Test and training data should have the same number of features" + ) + V = self.combined_coef_ - predictions = wrapper.predict_wrap(X, V) + if self.kernel == "linear": + predictions = wrapper.predict_wrap(X, V) + else: + n_class = len(self.encoder.classes_) + kernel_idx = wrapper.GENSVM_KERNEL_TYPES.index(self.kernel) + predictions = wrapper.predict_kernels_wrap( + X, + trainX, + V, + n_class, + kernel_idx, + self.gamma, + self.coef, + self.degree, + self.kernel_eigen_cutoff, + ) # Transform the classes back to the original form predictions -= 1 diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index 441c15b..6a896aa 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -129,6 +129,8 @@ cdef extern from "gensvm_helper.c": void free_data(GenData *) void set_verbosity(int) void gensvm_predict(char *, char *, long, long, long, char *) nogil + void gensvm_predict_kernels(char *, char *, char *, long, long, long, + long, long, long, int, double, double, double, double, char *) nogil void gensvm_train_q_helper(GenQueue *, char *, int) nogil void set_queue(GenQueue *, long, GenTask **) double get_task_duration(GenTask *) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 6d5bbae..e138620 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -133,6 +133,45 @@ def predict_wrap( return predictions + +def predict_kernels_wrap( + np.ndarray[np.float64_t, ndim=2, mode='c'] Xtest, + np.ndarray[np.float64_t, ndim=2, mode='c'] Xtrain, + np.ndarray[np.float64_t, ndim=2, mode='c'] V, + long n_class, + int kernel_idx, + double gamma, + double coef, + double degree, + double kernel_eigen_cutoff + ): + """ + Compute predictions for nonlinear GenSVM. Calls the C helper function + "gensvm_predict_kernels", which in turn calls the appropriate library + functions. + """ + + cdef long n_obs_test + cdef long n_obs_train + cdef long n_var + cdef long V_rows = V.shape[0] + cdef long V_cols = V.shape[1] + + n_obs_test = Xtest.shape[0] + n_obs_train = Xtrain.shape[0] + n_var = Xtrain.shape[1] + + cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions + predictions = np.empty((n_obs_test, ), dtype=np.int) + + with nogil: + gensvm_predict_kernels(Xtest.data, Xtrain.data, V.data, V_rows, + V_cols, n_obs_train, n_obs_test, n_var, n_class, kernel_idx, + gamma, coef, degree, kernel_eigen_cutoff, predictions.data) + + return predictions + + def grid_wrap( np.ndarray[np.float64_t, ndim=2, mode='c'] X, np.ndarray[np.int_t, ndim=1, mode='c'] y, |
