aboutsummaryrefslogtreecommitdiff
path: root/gensvm
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 19:15:18 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 19:15:18 -0500
commit4efec8aa479be05cd548f91458ac26a4ecff34f2 (patch)
treeb61719ea7ac3bb8c6ca5c7aa242d00e0afb07d94 /gensvm
parentCheck to make sure predict arrays are contiguous (diff)
downloadpygensvm-4efec8aa479be05cd548f91458ac26a4ecff34f2.tar.gz
pygensvm-4efec8aa479be05cd548f91458ac26a4ecff34f2.zip
Bugfix and test for predict method of GridSearch
Diffstat (limited to 'gensvm')
-rw-r--r--gensvm/gridsearch.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py
index ed12b97..24e0a74 100644
--- a/gensvm/gridsearch.py
+++ b/gensvm/gridsearch.py
@@ -598,6 +598,14 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
]
if self.refit:
+ # when we used a nonlinear kernel and specified no gamma, then
+ # gamma='auto' was used. We need to save the actual numerical value
+ # for use in the predict method later on, so we extract that here.
+ if (
+ not self.best_params_["kernel"] == "linear"
+ and not "gamma" in self.best_params_
+ ):
+ self.best_params_["gamma"] = 1. / X.shape[1]
self.best_estimator_ = GenSVM(**self.best_params_)
# y_orig because GenSVM fit must know the conversion for predict to
# work correctly
@@ -639,7 +647,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
self.multimetric_,
)
- def predict(self, X):
+ def predict(self, X, trainX=None):
"""Predict the class labels on the test data
Parameters
@@ -648,6 +656,10 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
Test data, where n_samples is the number of observations and
n_features is the number of features.
+ trainX : array, shape = [n_train_samples, n_features]
+ Only for nonlinear prediction with kernels: the training data used
+ to train the model.
+
Returns
-------
y_pred : array-like, shape = (n_samples, )
@@ -655,7 +667,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
"""
_skl_check_is_fitted(self, "predict", self.refit)
- return self.best_estimator_.predict(X)
+ return self.best_estimator_.predict(X, trainX=trainX)
def load_grid_tiny():