diff options
| -rw-r--r-- | CHANGELOG.rst | 5 | ||||
| -rw-r--r-- | gensvm/__init__.py | 2 | ||||
| -rw-r--r-- | gensvm/core.py | 4 | ||||
| -rw-r--r-- | test/test_core.py | 17 |
4 files changed, 26 insertions, 2 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fd5bedb..be4d1df 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,11 @@ Change Log ---------- +Version 0.2.3 +^^^^^^^^^^^^^ + +- Bugfix for prediction with gamma = 'auto' + Version 0.2.2 ^^^^^^^^^^^^^ diff --git a/gensvm/__init__.py b/gensvm/__init__.py index 62ac067..c8cd0be 100644 --- a/gensvm/__init__.py +++ b/gensvm/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "0.2.2" +__version__ = "0.2.3" from .core import GenSVM from .gridsearch import GenSVMGridSearchCV diff --git a/gensvm/core.py b/gensvm/core.py index c994995..f0e7820 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -408,6 +408,8 @@ class GenSVM(BaseEstimator, ClassifierMixin): trainX, accept_sparse=False, dtype=np.float64, order="C" ) + gamma = 1.0 / X.shape[1] if self.gamma == "auto" else self.gamma + V = self.combined_coef_ if self.kernel == "linear": predictions = wrapper.predict_wrap(X, V) @@ -420,7 +422,7 @@ class GenSVM(BaseEstimator, ClassifierMixin): V, n_class, kernel_idx, - self.gamma, + gamma, self.coef, self.degree, self.kernel_eigen_cutoff, diff --git a/test/test_core.py b/test/test_core.py index 0a1042a..63c172d 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -197,6 +197,23 @@ class GenSVMTestCase(unittest.TestCase): set(pred).issubset(set(["versicolor", "virginica", "setosa"])) ) + def test_fit_nonlinear_auto(self): + """ GENSVM: Fit and predict with nonlinear kernel """ + data = load_iris() + X = data.data + y = data.target_names[data.target] + + X_train, X_test, y_train, y_test = train_test_split( + X, y, random_state=123 + ) + clf = GenSVM(kernel="rbf", max_iter=1000, random_state=123) + clf.fit(X_train, y_train) + + pred = clf.predict(X_test, trainX=X_train) + self.assertTrue( + set(pred).issubset(set(["versicolor", "virginica", "setosa"])) + ) + def test_fit_with_seed(self): """ GENSVM: Test fit with seeding """ |
