From 4efec8aa479be05cd548f91458ac26a4ecff34f2 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Wed, 6 Mar 2019 19:15:18 -0500 Subject: Bugfix and test for predict method of GridSearch --- gensvm/gridsearch.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'gensvm') diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index ed12b97..24e0a74 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -598,6 +598,14 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): ] if self.refit: + # when we used a nonlinear kernel and specified no gamma, then + # gamma='auto' was used. We need to save the actual numerical value + # for use in the predict method later on, so we extract that here. + if ( + not self.best_params_["kernel"] == "linear" + and not "gamma" in self.best_params_ + ): + self.best_params_["gamma"] = 1. / X.shape[1] self.best_estimator_ = GenSVM(**self.best_params_) # y_orig because GenSVM fit must know the conversion for predict to # work correctly @@ -639,7 +647,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): self.multimetric_, ) - def predict(self, X): + def predict(self, X, trainX=None): """Predict the class labels on the test data Parameters @@ -648,6 +656,10 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Test data, where n_samples is the number of observations and n_features is the number of features. + trainX : array, shape = [n_train_samples, n_features] + Only for nonlinear prediction with kernels: the training data used + to train the model. + Returns ------- y_pred : array-like, shape = (n_samples, ) @@ -655,7 +667,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): """ _skl_check_is_fitted(self, "predict", self.refit) - return self.best_estimator_.predict(X) + return self.best_estimator_.predict(X, trainX=trainX) def load_grid_tiny(): -- cgit v1.2.3