diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-07 20:07:12 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-07 20:07:12 -0500 |
| commit | ff76f649a9f776bc006ba49607c692466ae09271 (patch) | |
| tree | 68f33de95bb394a5b34064abec17fef2cfc7cca7 /gensvm | |
| parent | bump version (diff) | |
| download | pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.tar.gz pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.zip | |
Add warning that shufflesplits unsupported
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/gridsearch.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index 7d453c9..0eeb603 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -17,7 +17,12 @@ import time from operator import itemgetter from sklearn.base import ClassifierMixin, BaseEstimator, MetaEstimatorMixin -from sklearn.model_selection import ParameterGrid, check_cv +from sklearn.model_selection import ( + ParameterGrid, + check_cv, + ShuffleSplit, + StratifiedShuffleSplit, +) from sklearn.model_selection._search import _check_param_grid from sklearn.model_selection._validation import _score from sklearn.preprocessing import LabelEncoder @@ -328,6 +333,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Refer to the `scikit-learn User Guide on cross validation`_ for the various strategies that can be used here. + NOTE: At the moment, the ShuffleSplit and StratifiedShuffleSplit are + not supported in this class. If you need these, you can use the GenSVM + classifier directly with the GridSearchCV object from scikit-learn. + (these methods require significant changes in the low-level library + before they can be supported). + refit : boolean, or string, default=True Refit the GenSVM estimator with the best found parameters on the whole dataset. @@ -520,6 +531,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): self.scoring = scoring self.cv = 5 if cv is None else cv + if isinstance(self.cv, ShuffleSplit) or isinstance( + self.cv, StratifiedShuffleSplit + ): + raise ValueError( + "ShuffleSplit and StratifiedShuffleSplit are not supported at the moment. Please see the documentation for more info" + ) self.refit = refit self.verbose = verbose self.return_train_score = return_train_score |
