aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-07 20:07:12 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-07 20:07:12 -0500
commitff76f649a9f776bc006ba49607c692466ae09271 (patch)
tree68f33de95bb394a5b34064abec17fef2cfc7cca7 /test
parentbump version (diff)
downloadpygensvm-ff76f649a9f776bc006ba49607c692466ae09271.tar.gz
pygensvm-ff76f649a9f776bc006ba49607c692466ae09271.zip
Add warning that shufflesplits unsupported
Diffstat (limited to 'test')
-rw-r--r--test/test_gridsearch.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/test/test_gridsearch.py b/test/test_gridsearch.py
index c723949..274d1bc 100644
--- a/test/test_gridsearch.py
+++ b/test/test_gridsearch.py
@@ -12,7 +12,11 @@ import numpy as np
import unittest
from sklearn.datasets import load_iris, load_digits
-from sklearn.model_selection import train_test_split
+from sklearn.model_selection import (
+ train_test_split,
+ StratifiedShuffleSplit,
+ ShuffleSplit,
+)
from sklearn.preprocessing import maxabs_scale
from gensvm.gridsearch import (
@@ -267,3 +271,15 @@ class GenSVMGridSearchCVTestCase(unittest.TestCase):
# low threshold on purpose for testing on Travis
# Real performance should be higher!
self.assertGreaterEqual(score, 0.70)
+
+ def test_gridsearch_stratified(self):
+ """ GENSVM_GRID: Error on using shufflesplit """
+ X, y = load_iris(return_X_y=True)
+
+ cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
+ with self.assertRaises(ValueError):
+ GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv)
+
+ cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
+ with self.assertRaises(ValueError):
+ GenSVMGridSearchCV(param_grid="tiny", verbose=1, cv=cv)