From 4efec8aa479be05cd548f91458ac26a4ecff34f2 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Wed, 6 Mar 2019 19:15:18 -0500 Subject: Bugfix and test for predict method of GridSearch --- gensvm/gridsearch.py | 16 ++++++++++++++-- test/test_gridsearch.py | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index ed12b97..24e0a74 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -598,6 +598,14 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): ] if self.refit: + # when we used a nonlinear kernel and specified no gamma, then + # gamma='auto' was used. We need to save the actual numerical value + # for use in the predict method later on, so we extract that here. + if ( + not self.best_params_["kernel"] == "linear" + and not "gamma" in self.best_params_ + ): + self.best_params_["gamma"] = 1. / X.shape[1] self.best_estimator_ = GenSVM(**self.best_params_) # y_orig because GenSVM fit must know the conversion for predict to # work correctly @@ -639,7 +647,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): self.multimetric_, ) - def predict(self, X): + def predict(self, X, trainX=None): """Predict the class labels on the test data Parameters @@ -648,6 +656,10 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Test data, where n_samples is the number of observations and n_features is the number of features. + trainX : array, shape = [n_train_samples, n_features] + Only for nonlinear prediction with kernels: the training data used + to train the model. + Returns ------- y_pred : array-like, shape = (n_samples, ) @@ -655,7 +667,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): """ _skl_check_is_fitted(self, "predict", self.refit) - return self.best_estimator_.predict(X) + return self.best_estimator_.predict(X, trainX=trainX) def load_grid_tiny(): diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index 1a29b0a..f4cff7e 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -172,6 +172,9 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): self.assertTrue(hasattr(clf, "best_params_")) + y_pred = clf.predict(X_test, trainX=X_train) + del y_pred + def test_invalid_y(self): """ GENSVM_GRID: Check raises for invalid y type """ pg = {"lmd": [1e-4, 100, 10000], "kernel": ["rbf"]} -- cgit v1.2.3