diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 19:15:18 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 19:15:18 -0500 |
| commit | 4efec8aa479be05cd548f91458ac26a4ecff34f2 (patch) | |
| tree | b61719ea7ac3bb8c6ca5c7aa242d00e0afb07d94 /gensvm | |
| parent | Check to make sure predict arrays are contiguous (diff) | |
| download | pygensvm-4efec8aa479be05cd548f91458ac26a4ecff34f2.tar.gz pygensvm-4efec8aa479be05cd548f91458ac26a4ecff34f2.zip | |
Bugfix and test for predict method of GridSearch
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/gridsearch.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index ed12b97..24e0a74 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -598,6 +598,14 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): ] if self.refit: + # when we used a nonlinear kernel and specified no gamma, then + # gamma='auto' was used. We need to save the actual numerical value + # for use in the predict method later on, so we extract that here. + if ( + not self.best_params_["kernel"] == "linear" + and not "gamma" in self.best_params_ + ): + self.best_params_["gamma"] = 1. / X.shape[1] self.best_estimator_ = GenSVM(**self.best_params_) # y_orig because GenSVM fit must know the conversion for predict to # work correctly @@ -639,7 +647,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): self.multimetric_, ) - def predict(self, X): + def predict(self, X, trainX=None): """Predict the class labels on the test data Parameters @@ -648,6 +656,10 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Test data, where n_samples is the number of observations and n_features is the number of features. + trainX : array, shape = [n_train_samples, n_features] + Only for nonlinear prediction with kernels: the training data used + to train the model. + Returns ------- y_pred : array-like, shape = (n_samples, ) @@ -655,7 +667,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): """ _skl_check_is_fitted(self, "predict", self.refit) - return self.best_estimator_.predict(X) + return self.best_estimator_.predict(X, trainX=trainX) def load_grid_tiny(): |
