author     Gertjan van den Burg <gertjanvandenburg@gmail.com>  2020-03-06 17:49:08 +0000
committer  Gertjan van den Burg <gertjanvandenburg@gmail.com>  2020-03-06 17:49:08 +0000
commit     ab69492525a49e64500da97fc6798b2b436db3f0 (patch)
tree       a4630cd6cc17437949827062a2b9639640bbbf73
parent     Update sphinx config (diff)
Documentation improvements
-rw-r--r--  gensvm/core.py        | 33
-rw-r--r--  gensvm/gridsearch.py  | 59
2 files changed, 65 insertions, 27 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index 45d59ad..169a30c 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -104,7 +104,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
errors. It is this flexibility that makes it perform well on diverse
datasets.
- The :func:`~GenSVM.fit` and :func:`~GenSVM.predict` methods of this class
+ The :meth:`~GenSVM.fit` and :meth:`~GenSVM.predict` methods of this class
use the GenSVM C library for the actual computations.
Parameters
@@ -123,7 +123,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
'group' for group size correction weights (equation 4 in the paper).
It is also possible to provide an explicit vector of sample weights
- through the :func:`~GenSVM.fit` method. If so, it will override the
+ through the :meth:`~GenSVM.fit` method. If so, it will override the
setting provided here.
kernel : string, optional (default='linear')
@@ -183,7 +183,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
See Also
--------
- :class:`.GenSVMGridSearchCV`:
+ :class:`~.gridsearch.GenSVMGridSearchCV`:
Helper class to run an efficient grid search for GenSVM.
@@ -257,8 +257,8 @@ class GenSVM(BaseEstimator, ClassifierMixin):
def fit(self, X, y, sample_weight=None, seed_V=None):
"""Fit the GenSVM model on the given data
- The model can be fit with or without a seed matrix (``seed_V``). This
- can be used to provide warm starts for the algorithm.
+ The model can be fit with or without a seed matrix (`seed_V`). This can
+ be used to provide warm starts for the algorithm.
Parameters
----------
@@ -280,14 +280,13 @@ class GenSVM(BaseEstimator, ClassifierMixin):
<.GenSVM.combined_coef_>` attribute of a different GenSVM model.
This is only supported for the linear kernel.
- NOTE: the size of the seed_V matrix is ``n_features+1`` by
- ``n_classes - 1``. The number of columns of ``seed_V`` is leading
- for the number of classes in the model. For example, if ``y``
- contains 3 different classes and ``seed_V`` has 3 columns, we
- assume that there are actually 4 classes in the problem but one
- class is just represented in this training data. This can be useful
- for problems were a certain class has only a few samples.
-
+ NOTE: the size of the seed_V matrix is `n_features+1` by `n_classes
+ - 1`. The number of columns of `seed_V` determines the number of
+ classes in the model. For example, if `y` contains 3 different
+ classes and `seed_V` has 3 columns, we assume that there are
+ actually 4 classes in the problem but one class is just not
+ represented in this training data. This can be useful for
+ problems where a certain class has only a few samples.
Returns
-------
@@ -354,7 +353,13 @@ class GenSVM(BaseEstimator, ClassifierMixin):
)
)
- self.coef_, self.intercept_, self.n_iter_, self.n_support_, self.SVs_ = _fit_gensvm(
+ (
+ self.coef_,
+ self.intercept_,
+ self.n_iter_,
+ self.n_support_,
+ self.SVs_,
+ ) = _fit_gensvm(
X,
y,
n_class,
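
The fit() docstring above describes warm starts via the seed_V argument, which
can be taken from the combined_coef_ attribute of another GenSVM model. Below
is a minimal sketch of that workflow; the toy dataset and the hyperparameter
values (p, kappa, lmd) are illustrative only, and the package-level import path
"from gensvm import GenSVM" is assumed.

from sklearn.datasets import make_classification
from gensvm import GenSVM

# Toy multiclass data (illustrative only).
X, y = make_classification(
    n_samples=200, n_features=10, n_informative=5, n_classes=3, random_state=1
)

# Cold start: fit a first model.
clf = GenSVM(p=2.0, kappa=0.5, lmd=1e-4)
clf.fit(X, y)

# Warm start: seed a second model with the combined coefficient matrix of the
# first one (only supported for the linear kernel, per the docstring above).
warm = GenSVM(p=2.0, kappa=0.5, lmd=1e-5)
warm.fit(X, y, seed_V=clf.combined_coef_)

print(warm.predict(X[:5]))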
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py
index 22125a4..550ae64 100644
--- a/gensvm/gridsearch.py
+++ b/gensvm/gridsearch.py
@@ -65,7 +65,7 @@ def _validate_param_grid(param_grid):
"""Check if the parameter values are valid
This basically does the same checks as in the constructor of the
- :class:`core.GenSVM` class, but for the entire parameter grid.
+ :class:`~.core.GenSVM` class, but for the entire parameter grid.
"""
# the conditions that the parameters must satisfy
@@ -169,16 +169,24 @@ def _format_results(
score_time = 0
if return_train_score:
- train_pred = predictions[cv_idx != test_idx,]
- y_train = true_y[cv_idx != test_idx,]
+ train_pred = predictions[
+ cv_idx != test_idx,
+ ]
+ y_train = true_y[
+ cv_idx != test_idx,
+ ]
train_score, score_t = _wrap_score(
train_pred, y_train, scorers, is_multimetric
)
score_time += score_t
ret.append(train_score)
- test_pred = predictions[cv_idx == test_idx,]
- y_test = true_y[cv_idx == test_idx,]
+ test_pred = predictions[
+ cv_idx == test_idx,
+ ]
+ y_test = true_y[
+ cv_idx == test_idx,
+ ]
test_score, score_t = _wrap_score(
test_pred, y_test, scorers, is_multimetric
)
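
For context on the reindented masks in _format_results above: cv_idx appears to
hold, for each sample, the index of the cross-validation fold it was assigned
to, and test_idx is the fold currently being scored, so the boolean comparisons
split the precomputed predictions into train and test parts. A toy sketch of
that split (array values are made up):

import numpy as np

cv_idx = np.array([0, 1, 2, 0, 1, 2])       # assumed fold index per sample
predictions = np.array([1, 0, 2, 1, 1, 0])  # precomputed predictions
true_y = np.array([1, 0, 2, 0, 1, 0])       # ground-truth labels

test_idx = 1
test_pred = predictions[cv_idx == test_idx]   # samples in the scored fold
train_pred = predictions[cv_idx != test_idx]  # samples in all other folds
y_test = true_y[cv_idx == test_idx]
y_train = true_y[cv_idx != test_idx]

print((test_pred == y_test).mean())  # per-fold test accuracy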
@@ -232,7 +240,7 @@ def _fit_grid_gensvm(
Returns
-------
cv_results_ : dict
- The cross validation results. See :func:`~GenSVMGridSearchCV.fit`.
+ The cross validation results. See :meth:`~GenSVMGridSearchCV.fit`.
"""
@@ -349,7 +357,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
The refitted estimator is made available at the :attr:`best_estimator_
<.GenSVMGridSearchCV.best_estimator_>` attribute and allows the user to
- use the :func:`~GenSVMGridSearchCV.predict` method directly on this
+ use the :meth:`~GenSVMGridSearchCV.predict` method directly on this
:class:`.GenSVMGridSearchCV` instance.
Also for multiple metric evaluation, the attributes :attr:`best_index_
@@ -623,7 +631,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
and not self.best_params_["kernel"] == "linear"
and not "gamma" in self.best_params_
):
- self.best_params_["gamma"] = 1. / X.shape[1]
+ self.best_params_["gamma"] = 1.0 / X.shape[1]
self.best_estimator_ = GenSVM(**self.best_params_)
# y_orig because GenSVM fit must know the conversion for predict to
# work correctly
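
The refit behaviour documented above (a best_estimator_ that backs predict() on
the grid-search object, plus the gamma default of 1.0 / n_features for
non-linear kernels) can be exercised with a short sketch like the one below. It
assumes the package-level import "from gensvm import GenSVMGridSearchCV" and
that the parameter grid is passed as the first constructor argument, as with
scikit-learn's GridSearchCV; the grid values are illustrative only.

from sklearn.datasets import make_classification
from gensvm import GenSVMGridSearchCV

X, y = make_classification(
    n_samples=200, n_features=10, n_informative=5, n_classes=3, random_state=1
)

# Illustrative grid over the GenSVM hyperparameters also used in the
# load_grid_tiny() grids shown below.
param_grid = {
    "p": [1.0, 2.0],
    "kappa": [0.5, 5.0],
    "lmd": [pow(2, -16), pow(2, -18)],
    "weights": ["unit"],
}

gg = GenSVMGridSearchCV(param_grid)
gg.fit(X, y)

print(gg.best_params_)    # best hyperparameter combination found
print(gg.predict(X[:5]))  # uses the refitted best_estimator_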
@@ -711,9 +719,24 @@ def load_grid_tiny():
"""
pg = [
- {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -16)], "weights": ["unit"]},
- {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -18)], "weights": ["unit"]},
- {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -18)], "weights": ["unit"]},
+ {
+ "p": [2.0],
+ "kappa": [5.0],
+ "lmd": [pow(2, -16)],
+ "weights": ["unit"],
+ },
+ {
+ "p": [2.0],
+ "kappa": [5.0],
+ "lmd": [pow(2, -18)],
+ "weights": ["unit"],
+ },
+ {
+ "p": [2.0],
+ "kappa": [0.5],
+ "lmd": [pow(2, -18)],
+ "weights": ["unit"],
+ },
{
"p": [2.0],
"kappa": [5.0],
@@ -726,7 +749,12 @@ def load_grid_tiny():
"lmd": [pow(2, -18)],
"weights": ["unit"],
},
- {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -14)], "weights": ["unit"]},
+ {
+ "p": [2.0],
+ "kappa": [5.0],
+ "lmd": [pow(2, -14)],
+ "weights": ["unit"],
+ },
{
"p": [2.0],
"kappa": [0.5],
@@ -739,7 +767,12 @@ def load_grid_tiny():
"lmd": [pow(2, -18)],
"weights": ["unit"],
},
- {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -16)], "weights": ["unit"]},
+ {
+ "p": [2.0],
+ "kappa": [0.5],
+ "lmd": [pow(2, -16)],
+ "weights": ["unit"],
+ },
{
"p": [2.0],
"kappa": [0.5],