diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-07 20:07:12 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-07 20:07:12 -0500 |
| commit | ff76f649a9f776bc006ba49607c692466ae09271 (patch) | |
| tree | 68f33de95bb394a5b34064abec17fef2cfc7cca7 /test | |
| parent | bump version (diff) | |
| download | pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.tar.gz pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.zip | |
Add warning that shufflesplits unsupported
Diffstat (limited to 'test')
| -rw-r--r-- | test/test_gridsearch.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index c723949..274d1bc 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -12,7 +12,11 @@ import numpy as np import unittest from sklearn.datasets import load_iris, load_digits -from sklearn.model_selection import train_test_split +from sklearn.model_selection import ( + train_test_split, + StratifiedShuffleSplit, + ShuffleSplit, +) from sklearn.preprocessing import maxabs_scale from gensvm.gridsearch import ( @@ -267,3 +271,15 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): # low threshold on purpose for testing on Travis # Real performance should be higher! self.assertGreaterEqual(score, 0.70) + + def test_gridsearch_stratified(self): + """ GENSVM_GRID: Error on using shufflesplit """ + X, y = load_iris(return_X_y=True) + + cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42) + with self.assertRaises(ValueError): + GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv) + + cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42) + with self.assertRaises(ValueError): + GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv) |
