aboutsummaryrefslogtreecommitdiff
path: root/gensvm/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/core.py')
-rw-r--r--gensvm/core.py36
1 files changed, 33 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index a2c4cd0..ce416d0 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -343,20 +343,50 @@ class GenSVM(BaseEstimator, ClassifierMixin):
)
return self
- def predict(self, X):
+ def predict(self, X, trainX=None):
"""Predict the class labels on the given data
Parameters
----------
- X : array, shape = [n_samples, n_features]
+ X : array, shape = [n_test_samples, n_features]
+ Data for which to predict the labels
+
+ trainX : array, shape = [n_train_samples, n_features]
+ Only for nonlinear prediction with kernels: the training data used
+ to train the model.
Returns
-------
y_pred : array, shape = (n_samples, )
"""
+
+ if (not self.kernel == "linear") and trainX is None:
+ raise ValueError(
+ "Training data must be provided with nonlinear prediction"
+ )
+ if not trainX is None and not X.shape[1] == trainX.shape[1]:
+ raise ValueError(
+ "Test and training data should have the same number of features"
+ )
+
V = self.combined_coef_
- predictions = wrapper.predict_wrap(X, V)
+ if self.kernel == "linear":
+ predictions = wrapper.predict_wrap(X, V)
+ else:
+ n_class = len(self.encoder.classes_)
+ kernel_idx = wrapper.GENSVM_KERNEL_TYPES.index(self.kernel)
+ predictions = wrapper.predict_kernels_wrap(
+ X,
+ trainX,
+ V,
+ n_class,
+ kernel_idx,
+ self.gamma,
+ self.coef,
+ self.degree,
+ self.kernel_eigen_cutoff,
+ )
# Transform the classes back to the original form
predictions -= 1