From ff76f649a9f776bc006ba49607c692466ae09271 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Thu, 7 Mar 2019 20:07:12 -0500 Subject: Add warning that shufflesplits unsupported --- test/test_gridsearch.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py index c723949..274d1bc 100644 --- a/test/test_gridsearch.py +++ b/test/test_gridsearch.py @@ -12,7 +12,11 @@ import numpy as np import unittest from sklearn.datasets import load_iris, load_digits -from sklearn.model_selection import train_test_split +from sklearn.model_selection import ( + train_test_split, + StratifiedShuffleSplit, + ShuffleSplit, +) from sklearn.preprocessing import maxabs_scale from gensvm.gridsearch import ( @@ -267,3 +271,15 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase): # low threshold on purpose for testing on Travis # Real performance should be higher! self.assertGreaterEqual(score, 0.70) + + def test_gridsearch_stratified(self): + """ GENSVM_GRID: Error on using shufflesplit """ + X, y = load_iris(return_X_y=True) + + cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42) + with self.assertRaises(ValueError): + GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv) + + cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42) + with self.assertRaises(ValueError): + GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv) -- cgit v1.2.3