aboutsummaryrefslogtreecommitdiff
path: root/gensvm
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-07 20:07:12 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-07 20:07:12 -0500
commitff76f649a9f776bc006ba49607c692466ae09271 (patch)
tree68f33de95bb394a5b34064abec17fef2cfc7cca7 /gensvm
parentbump version (diff)
downloadpygensvm-ff76f649a9f776bc006ba49607c692466ae09271.tar.gz
pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.zip
Add warning that shufflesplits unsupported
Diffstat (limited to 'gensvm')
-rw-r--r--gensvm/gridsearch.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py
index 7d453c9..0eeb603 100644
--- a/gensvm/gridsearch.py
+++ b/gensvm/gridsearch.py
@@ -17,7 +17,12 @@ import time
from operator import itemgetter
from sklearn.base import ClassifierMixin, BaseEstimator, MetaEstimatorMixin
-from sklearn.model_selection import ParameterGrid, check_cv
+from sklearn.model_selection import (
+ ParameterGrid,
+ check_cv,
+ ShuffleSplit,
+ StratifiedShuffleSplit,
+)
from sklearn.model_selection._search import _check_param_grid
from sklearn.model_selection._validation import _score
from sklearn.preprocessing import LabelEncoder
@@ -328,6 +333,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
Refer to the `scikit-learn User Guide on cross validation`_ for the
various strategies that can be used here.
+ NOTE: At the moment, the ShuffleSplit and StratifiedShuffleSplit are
+ not supported in this class. If you need these, you can use the GenSVM
+ classifier directly with the GridSearchCV object from scikit-learn.
+ (these methods require significant changes in the low-level library
+ before they can be supported).
+
refit : boolean, or string, default=True
Refit the GenSVM estimator with the best found parameters on the whole
dataset.
@@ -520,6 +531,12 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
self.scoring = scoring
self.cv = 5 if cv is None else cv
+ if isinstance(self.cv, ShuffleSplit) or isinstance(
+ self.cv, StratifiedShuffleSplit
+ ):
+ raise ValueError(
+ "ShuffleSplit and StratifiedShuffleSplit are not supported at the moment. Please see the documentation for more info"
+ )
self.refit = refit
self.verbose = verbose
self.return_train_score = return_train_score