diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-01-15 12:21:24 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-01-15 12:21:24 +0000 |
| commit | d1ddd504802072d930170b802d2cf98fb309cd46 (patch) | |
| tree | 2421ba88a37d686eca467cf85960ab7da4ae991a /gensvm/core.py | |
| parent | Move wrapper to better folder structure (diff) | |
| download | pygensvm-d1ddd504802072d930170b802d2cf98fb309cd46.tar.gz pygensvm-d1ddd504802072d930170b802d2cf98fb309cd46.zip | |
Code formatting with Black
Diffstat (limited to 'gensvm/core.py')
| -rw-r--r-- | gensvm/core.py | 179 |
1 files changed, 129 insertions, 50 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index 77a3a7f..edd5236 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -21,9 +21,25 @@ from sklearn.utils.validation import check_is_fitted from .cython_wrapper import wrapper -def _fit_gensvm(X, y, n_class, p, lmd, kappa, epsilon, weights, kernel, gamma, - coef, degree, kernel_eigen_cutoff, verbose, max_iter, - random_state=None, seed_V=None): +def _fit_gensvm( + X, + y, + n_class, + p, + lmd, + kappa, + epsilon, + weights, + kernel, + gamma, + coef, + degree, + kernel_eigen_cutoff, + verbose, + max_iter, + random_state=None, + seed_V=None, +): # process the random state rnd = check_random_state(random_state) @@ -32,23 +48,41 @@ def _fit_gensvm(X, y, n_class, p, lmd, kappa, epsilon, weights, kernel, gamma, wrapper.set_verbosity_wrap(verbose) # convert the weight index - weight_idx = 1 if weights == 'unit' else 2 + weight_idx = 1 if weights == "unit" else 2 # run the actual training raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap( - X, y, n_class, p, lmd, kappa, epsilon, weight_idx, kernel, gamma, - coef, degree, kernel_eigen_cutoff, max_iter, - rnd.randint(np.iinfo('i').max), seed_V) + X, + y, + n_class, + p, + lmd, + kappa, + epsilon, + weight_idx, + kernel, + gamma, + coef, + degree, + kernel_eigen_cutoff, + max_iter, + rnd.randint(np.iinfo("i").max), + seed_V, + ) # process output if status_ == 1 and verbose > 0: - warnings.warn("GenSVM optimization prematurely ended due to a " - "incorrect step in the optimization algorithm.", - FitFailedWarning) + warnings.warn( + "GenSVM optimization prematurely ended due to a " + "incorrect step in the optimization algorithm.", + FitFailedWarning, + ) if status_ == 2 and verbose > 0: - warnings.warn("GenSVM failed to converge, increase " - "the number of iterations.", ConvergenceWarning) + warnings.warn( + "GenSVM failed to converge, increase " "the number of iterations.", + ConvergenceWarning, + ) coef_ = raw_coef_[1:, :] intercept_ = raw_coef_[0, :] @@ -141,32 +175,53 @@ class GenSVM(BaseEstimator, ClassifierMixin): """ - def __init__(self, p=1.0, lmd=1e-5, kappa=0.0, epsilon=1e-6, - weights='unit', kernel='linear', gamma='auto', coef=1.0, - degree=2.0, kernel_eigen_cutoff=1e-8, verbose=0, random_state=None, - max_iter=1e8): + def __init__( + self, + p=1.0, + lmd=1e-5, + kappa=0.0, + epsilon=1e-6, + weights="unit", + kernel="linear", + gamma="auto", + coef=1.0, + degree=2.0, + kernel_eigen_cutoff=1e-8, + verbose=0, + random_state=None, + max_iter=1e8, + ): if not 1.0 <= p <= 2.0: - raise ValueError("Value for p should be within [1, 2]; got p = %r" - % p) + raise ValueError( + "Value for p should be within [1, 2]; got p = %r" % p + ) if not kappa > -1.0: - raise ValueError("Value for kappa should be larger than -1; got " - "kappa = %r" % kappa) + raise ValueError( + "Value for kappa should be larger than -1; got " + "kappa = %r" % kappa + ) if not lmd > 0: - raise ValueError("Value for lmd should be larger than 0; got " - "lmd = %r" % lmd) + raise ValueError( + "Value for lmd should be larger than 0; got " "lmd = %r" % lmd + ) if not epsilon > 0: - raise ValueError("Value for epsilon should be larger than 0; got " - "epsilon = %r" % epsilon) + raise ValueError( + "Value for epsilon should be larger than 0; got " + "epsilon = %r" % epsilon + ) if gamma == 0.0: raise ValueError("A gamma value of 0.0 is invalid") - if not weights in ('unit', 'group'): - raise ValueError("Unknown weight parameter specified. Should be " - "'unit' or 'group'; got %r" % weights) - if not kernel in ('linear', 'rbf', 'poly', 'sigmoid'): - raise ValueError("Unknown kernel specified. Should be " - "'linear', 'rbf', 'poly', or 'sigmoid'; got %r" % kernel) - + if not weights in ("unit", "group"): + raise ValueError( + "Unknown weight parameter specified. Should be " + "'unit' or 'group'; got %r" % weights + ) + if not kernel in ("linear", "rbf", "poly", "sigmoid"): + raise ValueError( + "Unknown kernel specified. Should be " + "'linear', 'rbf', 'poly', or 'sigmoid'; got %r" % kernel + ) self.p = p self.lmd = lmd @@ -182,7 +237,6 @@ class GenSVM(BaseEstimator, ClassifierMixin): self.random_state = random_state self.max_iter = max_iter - def fit(self, X, y, seed_V=None): """Fit the GenSVM model on the given data @@ -219,44 +273,69 @@ class GenSVM(BaseEstimator, ClassifierMixin): Returns self. """ - X, y_org = check_X_y(X, y, accept_sparse=False, dtype=np.float64, - order="C") + X, y_org = check_X_y( + X, y, accept_sparse=False, dtype=np.float64, order="C" + ) y_type = type_of_target(y_org) if y_type not in ["binary", "multiclass"]: raise ValueError("Label type not allowed for GenSVM: %r" % y_type) - if self.gamma == 'auto': + if self.gamma == "auto": gamma = 1 / X.shape[1] else: gamma = self.gamma - # This is necessary because GenSVM expects classes to go from 1 to + # This is necessary because GenSVM expects classes to go from 1 to # n_class self.encoder = LabelEncoder() y = self.encoder.fit_transform(y_org) y += 1 n_class = len(np.unique(y)) - if not seed_V is None and self.kernel != 'linear': - warnings.warn("Warm starts are only supported for the " - "linear kernel. The seed_V parameter will be ignored.") + if not seed_V is None and self.kernel != "linear": + warnings.warn( + "Warm starts are only supported for the " + "linear kernel. The seed_V parameter will be ignored." + ) seed_V = None if not seed_V is None: n_samples, n_features = X.shape if seed_V.shape[1] + 1 > n_class: n_class = seed_V.shape[1] - if seed_V.shape[0] - 1 != n_features or (seed_V.shape[1] + 1 < - n_class): - raise ValueError("Seed V must have shape [%i, %i], " - "but has shape [%i, %i]" % (n_features+1, n_class-1, - seed_V.shape[0], seed_V.shape[1])) - - self.coef_, self.intercept_, self.n_iter_, self.n_support_ = \ - _fit_gensvm(X, y, n_class, self.p, self.lmd, self.kappa, - self.epsilon, self.weights, self.kernel, gamma, - self.coef, self.degree, self.kernel_eigen_cutoff, - self.verbose, self.max_iter, self.random_state, seed_V) + if seed_V.shape[0] - 1 != n_features or ( + seed_V.shape[1] + 1 < n_class + ): + raise ValueError( + "Seed V must have shape [%i, %i], " + "but has shape [%i, %i]" + % ( + n_features + 1, + n_class - 1, + seed_V.shape[0], + seed_V.shape[1], + ) + ) + + self.coef_, self.intercept_, self.n_iter_, self.n_support_ = _fit_gensvm( + X, + y, + n_class, + self.p, + self.lmd, + self.kappa, + self.epsilon, + self.weights, + self.kernel, + gamma, + self.coef, + self.degree, + self.kernel_eigen_cutoff, + self.verbose, + self.max_iter, + self.random_state, + seed_V, + ) return self def predict(self, X): |
