aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-05-17 15:17:32 -0400
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-05-17 15:17:32 -0400
commitad5b0a7800518baf4ac450257d045e5fa2e735d0 (patch)
treef16b60a75ffa60026e1b165a7a216db02442c0bb
parentupdate version (diff)
downloadpygensvm-ad5b0a7800518baf4ac450257d045e5fa2e735d0.tar.gz
pygensvm-ad5b0a7800518baf4ac450257d045e5fa2e735d0.zip
bugfix for predict with gamma = 'auto'
-rw-r--r--CHANGELOG.rst5
-rw-r--r--gensvm/__init__.py2
-rw-r--r--gensvm/core.py4
-rw-r--r--test/test_core.py17
4 files changed, 26 insertions, 2 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index fd5bedb..be4d1df 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,11 @@
Change Log
----------
+Version 0.2.3
+^^^^^^^^^^^^^
+
+- Bugfix for prediction with gamma = 'auto'
+
Version 0.2.2
^^^^^^^^^^^^^
diff --git a/gensvm/__init__.py b/gensvm/__init__.py
index 62ac067..c8cd0be 100644
--- a/gensvm/__init__.py
+++ b/gensvm/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-__version__ = "0.2.2"
+__version__ = "0.2.3"
from .core import GenSVM
from .gridsearch import GenSVMGridSearchCV
diff --git a/gensvm/core.py b/gensvm/core.py
index c994995..f0e7820 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -408,6 +408,8 @@ class GenSVM(BaseEstimator, ClassifierMixin):
trainX, accept_sparse=False, dtype=np.float64, order="C"
)
+ gamma = 1.0 / X.shape[1] if self.gamma == "auto" else self.gamma
+
V = self.combined_coef_
if self.kernel == "linear":
predictions = wrapper.predict_wrap(X, V)
@@ -420,7 +422,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
V,
n_class,
kernel_idx,
- self.gamma,
+ gamma,
self.coef,
self.degree,
self.kernel_eigen_cutoff,
diff --git a/test/test_core.py b/test/test_core.py
index 0a1042a..63c172d 100644
--- a/test/test_core.py
+++ b/test/test_core.py
@@ -197,6 +197,23 @@ class GenSVMTestCase(unittest.TestCase):
set(pred).issubset(set(["versicolor", "virginica", "setosa"]))
)
+ def test_fit_nonlinear_auto(self):
+ """ GENSVM: Fit and predict with nonlinear kernel """
+ data = load_iris()
+ X = data.data
+ y = data.target_names[data.target]
+
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, random_state=123
+ )
+ clf = GenSVM(kernel="rbf", max_iter=1000, random_state=123)
+ clf.fit(X_train, y_train)
+
+ pred = clf.predict(X_test, trainX=X_train)
+ self.assertTrue(
+ set(pred).issubset(set(["versicolor", "virginica", "setosa"]))
+ )
+
def test_fit_with_seed(self):
""" GENSVM: Test fit with seeding """