diff options
Diffstat (limited to 'gensvm/gridsearch.py')
| -rw-r--r-- | gensvm/gridsearch.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index 7d453c9..0eeb603 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -17,7 +17,12 @@ import time from operator import itemgetter from sklearn.base import ClassifierMixin, BaseEstimator, MetaEstimatorMixin -from sklearn.model_selection import ParameterGrid, check_cv +from sklearn.model_selection import ( + ParameterGrid, + check_cv, + ShuffleSplit, + StratifiedShuffleSplit, +) from sklearn.model_selection._search import _check_param_grid from sklearn.model_selection._validation import _score from sklearn.preprocessing import LabelEncoder @@ -328,6 +333,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Refer to the `scikit-learn User Guide on cross validation`_ for the various strategies that can be used here. + NOTE: At the moment, the ShuffleSplit and StratifiedShuffleSplit are + not supported in this class. If you need these, you can use the GenSVM + classifier directly with the GridSearchCV object from scikit-learn. + (these methods require significant changes in the low-level library + before they can be supported). + refit : boolean, or string, default=True Refit the GenSVM estimator with the best found parameters on the whole dataset. @@ -520,6 +531,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): self.scoring = scoring self.cv = 5 if cv is None else cv + if isinstance(self.cv, ShuffleSplit) or isinstance( + self.cv, StratifiedShuffleSplit + ): + raise ValueError( + "ShuffleSplit and StratifiedShuffleSplit are not supported at the moment. Please see the documentation for more info" + ) self.refit = refit self.verbose = verbose self.return_train_score = return_train_score |
