diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-07 20:07:12 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-07 20:07:12 -0500 |
| commit | ff76f649a9f776bc006ba49607c692466ae09271 (patch) | |
| tree | 68f33de95bb394a5b34064abec17fef2cfc7cca7 | |
| parent | bump version (diff) | |
| download | pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.tar.gz pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.zip | |
Add warning that shufflesplits unsupported
| -rw-r--r-- | gensvm/gridsearch.py | 19 | ||||
| -rw-r--r-- | test/test_gridsearch.py | 18 |
2 files changed, 35 insertions, 2 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index 7d453c9..0eeb603 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -17,7 +17,12 @@ import time from operator import itemgetter from sklearn.base import ClassifierMixin, BaseEstimator, MetaEstimatorMixin -from sklearn.model_selection import ParameterGrid, check_cv +from sklearn.model_selection import ( + ParameterGrid, + check_cv, + ShuffleSplit, + StratifiedShuffleSplit, +) from sklearn.model_selection._search import _check_param_grid from sklearn.model_selection._validation import _score from sklearn.preprocessing import LabelEncoder @@ -328,6 +333,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Refer to the `scikit-learn User Guide on cross validation`_ for the various strategies that can be used here. + NOTE: At the moment, the ShuffleSplit and StratifiedShuffleSplit are + not supported in this class. If you need these, you can use the GenSVM + classifier directly with the GridSearchCV object from scikit-learn. + (these methods require significant changes in the low-level library + before they can be supported). + refit : boolean, or string, default=True Refit the GenSVM estimator with the best found parameters on the whole dataset. @@ -520,6 +531,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): self.scoring = scoring self.cv = 5 if cv is None else cv + if isinstance(self.cv, ShuffleSplit) or isinstance( + self.cv, StratifiedShuffleSplit + ): + raise ValueError( + "ShuffleSplit and StratifiedShuffleSplit are not supported at the moment. Please see the documentation for more info" + ) self.refit = refit self.verbose = verbose self.return_train_score = return_train_score diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index c723949..274d1bc 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -12,7 +12,11 @@ import numpy as np import unittest from sklearn.datasets import load_iris, load_digits -from sklearn.model_selection import train_test_split +from sklearn.model_selection import ( + train_test_split, + StratifiedShuffleSplit, + ShuffleSplit, +) from sklearn.preprocessing import maxabs_scale from gensvm.gridsearch import ( @@ -267,3 +271,15 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): # low threshold on purpose for testing on Travis # Real performance should be higher! self.assertGreaterEqual(score, 0.70) + + def test_gridsearch_stratified(self): + """ GENSVM_GRID: Error on using shufflesplit """ + X, y = load_iris(return_X_y=True) + + cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42) + with self.assertRaises(ValueError): + GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv) + + cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42) + with self.assertRaises(ValueError): + GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv) |
