aboutsummaryrefslogtreecommitdiff
path: root/gensvm/core.py
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-01-15 12:21:24 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-01-15 12:21:24 +0000
commitd1ddd504802072d930170b802d2cf98fb309cd46 (patch)
tree2421ba88a37d686eca467cf85960ab7da4ae991a /gensvm/core.py
parentMove wrapper to better folder structure (diff)
downloadpygensvm-d1ddd504802072d930170b802d2cf98fb309cd46.tar.gz
pygensvm-d1ddd504802072d930170b802d2cf98fb309cd46.zip
Code formatting with Black
Diffstat (limited to 'gensvm/core.py')
-rw-r--r--gensvm/core.py179
1 files changed, 129 insertions, 50 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index 77a3a7f..edd5236 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -21,9 +21,25 @@ from sklearn.utils.validation import check_is_fitted
from .cython_wrapper import wrapper
-def _fit_gensvm(X, y, n_class, p, lmd, kappa, epsilon, weights, kernel, gamma,
- coef, degree, kernel_eigen_cutoff, verbose, max_iter,
- random_state=None, seed_V=None):
+def _fit_gensvm(
+ X,
+ y,
+ n_class,
+ p,
+ lmd,
+ kappa,
+ epsilon,
+ weights,
+ kernel,
+ gamma,
+ coef,
+ degree,
+ kernel_eigen_cutoff,
+ verbose,
+ max_iter,
+ random_state=None,
+ seed_V=None,
+):
# process the random state
rnd = check_random_state(random_state)
@@ -32,23 +48,41 @@ def _fit_gensvm(X, y, n_class, p, lmd, kappa, epsilon, weights, kernel, gamma,
wrapper.set_verbosity_wrap(verbose)
# convert the weight index
- weight_idx = 1 if weights == 'unit' else 2
+ weight_idx = 1 if weights == "unit" else 2
# run the actual training
raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap(
- X, y, n_class, p, lmd, kappa, epsilon, weight_idx, kernel, gamma,
- coef, degree, kernel_eigen_cutoff, max_iter,
- rnd.randint(np.iinfo('i').max), seed_V)
+ X,
+ y,
+ n_class,
+ p,
+ lmd,
+ kappa,
+ epsilon,
+ weight_idx,
+ kernel,
+ gamma,
+ coef,
+ degree,
+ kernel_eigen_cutoff,
+ max_iter,
+ rnd.randint(np.iinfo("i").max),
+ seed_V,
+ )
# process output
if status_ == 1 and verbose > 0:
- warnings.warn("GenSVM optimization prematurely ended due to a "
- "incorrect step in the optimization algorithm.",
- FitFailedWarning)
+ warnings.warn(
+ "GenSVM optimization prematurely ended due to a "
+ "incorrect step in the optimization algorithm.",
+ FitFailedWarning,
+ )
if status_ == 2 and verbose > 0:
- warnings.warn("GenSVM failed to converge, increase "
- "the number of iterations.", ConvergenceWarning)
+ warnings.warn(
+ "GenSVM failed to converge, increase " "the number of iterations.",
+ ConvergenceWarning,
+ )
coef_ = raw_coef_[1:, :]
intercept_ = raw_coef_[0, :]
@@ -141,32 +175,53 @@ class GenSVM(BaseEstimator, ClassifierMixin):
"""
- def __init__(self, p=1.0, lmd=1e-5, kappa=0.0, epsilon=1e-6,
- weights='unit', kernel='linear', gamma='auto', coef=1.0,
- degree=2.0, kernel_eigen_cutoff=1e-8, verbose=0, random_state=None,
- max_iter=1e8):
+ def __init__(
+ self,
+ p=1.0,
+ lmd=1e-5,
+ kappa=0.0,
+ epsilon=1e-6,
+ weights="unit",
+ kernel="linear",
+ gamma="auto",
+ coef=1.0,
+ degree=2.0,
+ kernel_eigen_cutoff=1e-8,
+ verbose=0,
+ random_state=None,
+ max_iter=1e8,
+ ):
if not 1.0 <= p <= 2.0:
- raise ValueError("Value for p should be within [1, 2]; got p = %r"
- % p)
+ raise ValueError(
+ "Value for p should be within [1, 2]; got p = %r" % p
+ )
if not kappa > -1.0:
- raise ValueError("Value for kappa should be larger than -1; got "
- "kappa = %r" % kappa)
+ raise ValueError(
+ "Value for kappa should be larger than -1; got "
+ "kappa = %r" % kappa
+ )
if not lmd > 0:
- raise ValueError("Value for lmd should be larger than 0; got "
- "lmd = %r" % lmd)
+ raise ValueError(
+ "Value for lmd should be larger than 0; got " "lmd = %r" % lmd
+ )
if not epsilon > 0:
- raise ValueError("Value for epsilon should be larger than 0; got "
- "epsilon = %r" % epsilon)
+ raise ValueError(
+ "Value for epsilon should be larger than 0; got "
+ "epsilon = %r" % epsilon
+ )
if gamma == 0.0:
raise ValueError("A gamma value of 0.0 is invalid")
- if not weights in ('unit', 'group'):
- raise ValueError("Unknown weight parameter specified. Should be "
- "'unit' or 'group'; got %r" % weights)
- if not kernel in ('linear', 'rbf', 'poly', 'sigmoid'):
- raise ValueError("Unknown kernel specified. Should be "
- "'linear', 'rbf', 'poly', or 'sigmoid'; got %r" % kernel)
-
+ if not weights in ("unit", "group"):
+ raise ValueError(
+ "Unknown weight parameter specified. Should be "
+ "'unit' or 'group'; got %r" % weights
+ )
+ if not kernel in ("linear", "rbf", "poly", "sigmoid"):
+ raise ValueError(
+ "Unknown kernel specified. Should be "
+ "'linear', 'rbf', 'poly', or 'sigmoid'; got %r" % kernel
+ )
self.p = p
self.lmd = lmd
@@ -182,7 +237,6 @@ class GenSVM(BaseEstimator, ClassifierMixin):
self.random_state = random_state
self.max_iter = max_iter
-
def fit(self, X, y, seed_V=None):
"""Fit the GenSVM model on the given data
@@ -219,44 +273,69 @@ class GenSVM(BaseEstimator, ClassifierMixin):
Returns self.
"""
- X, y_org = check_X_y(X, y, accept_sparse=False, dtype=np.float64,
- order="C")
+ X, y_org = check_X_y(
+ X, y, accept_sparse=False, dtype=np.float64, order="C"
+ )
y_type = type_of_target(y_org)
if y_type not in ["binary", "multiclass"]:
raise ValueError("Label type not allowed for GenSVM: %r" % y_type)
- if self.gamma == 'auto':
+ if self.gamma == "auto":
gamma = 1 / X.shape[1]
else:
gamma = self.gamma
- # This is necessary because GenSVM expects classes to go from 1 to
+ # This is necessary because GenSVM expects classes to go from 1 to
# n_class
self.encoder = LabelEncoder()
y = self.encoder.fit_transform(y_org)
y += 1
n_class = len(np.unique(y))
- if not seed_V is None and self.kernel != 'linear':
- warnings.warn("Warm starts are only supported for the "
- "linear kernel. The seed_V parameter will be ignored.")
+ if not seed_V is None and self.kernel != "linear":
+ warnings.warn(
+ "Warm starts are only supported for the "
+ "linear kernel. The seed_V parameter will be ignored."
+ )
seed_V = None
if not seed_V is None:
n_samples, n_features = X.shape
if seed_V.shape[1] + 1 > n_class:
n_class = seed_V.shape[1]
- if seed_V.shape[0] - 1 != n_features or (seed_V.shape[1] + 1 <
- n_class):
- raise ValueError("Seed V must have shape [%i, %i], "
- "but has shape [%i, %i]" % (n_features+1, n_class-1,
- seed_V.shape[0], seed_V.shape[1]))
-
- self.coef_, self.intercept_, self.n_iter_, self.n_support_ = \
- _fit_gensvm(X, y, n_class, self.p, self.lmd, self.kappa,
- self.epsilon, self.weights, self.kernel, gamma,
- self.coef, self.degree, self.kernel_eigen_cutoff,
- self.verbose, self.max_iter, self.random_state, seed_V)
+ if seed_V.shape[0] - 1 != n_features or (
+ seed_V.shape[1] + 1 < n_class
+ ):
+ raise ValueError(
+ "Seed V must have shape [%i, %i], "
+ "but has shape [%i, %i]"
+ % (
+ n_features + 1,
+ n_class - 1,
+ seed_V.shape[0],
+ seed_V.shape[1],
+ )
+ )
+
+ self.coef_, self.intercept_, self.n_iter_, self.n_support_ = _fit_gensvm(
+ X,
+ y,
+ n_class,
+ self.p,
+ self.lmd,
+ self.kappa,
+ self.epsilon,
+ self.weights,
+ self.kernel,
+ gamma,
+ self.coef,
+ self.degree,
+ self.kernel_eigen_cutoff,
+ self.verbose,
+ self.max_iter,
+ self.random_state,
+ seed_V,
+ )
return self
def predict(self, X):