diff options
Diffstat (limited to 'gensvm/core.py')
| -rw-r--r-- | gensvm/core.py | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index a2c4cd0..ce416d0 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -343,20 +343,50 @@ class GenSVM(BaseEstimator, ClassifierMixin): ) return self - def predict(self, X): + def predict(self, X, trainX=None): """Predict the class labels on the given data Parameters ---------- - X : array, shape = [n_samples, n_features] + X : array, shape = [n_test_samples, n_features] + Data for which to predict the labels + + trainX : array, shape = [n_train_samples, n_features] + Only for nonlinear prediction with kernels: the training data used + to train the model. Returns ------- y_pred : array, shape = (n_samples, ) """ + + if (not self.kernel == "linear") and trainX is None: + raise ValueError( + "Training data must be provided with nonlinear prediction" + ) + if not trainX is None and not X.shape[1] == trainX.shape[1]: + raise ValueError( + "Test and training data should have the same number of features" + ) + V = self.combined_coef_ - predictions = wrapper.predict_wrap(X, V) + if self.kernel == "linear": + predictions = wrapper.predict_wrap(X, V) + else: + n_class = len(self.encoder.classes_) + kernel_idx = wrapper.GENSVM_KERNEL_TYPES.index(self.kernel) + predictions = wrapper.predict_kernels_wrap( + X, + trainX, + V, + n_class, + kernel_idx, + self.gamma, + self.coef, + self.degree, + self.kernel_eigen_cutoff, + ) # Transform the classes back to the original form predictions -= 1 |
