diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 17:49:08 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 17:49:08 +0000 |
| commit | ab69492525a49e64500da97fc6798b2b436db3f0 (patch) | |
| tree | a4630cd6cc17437949827062a2b9639640bbbf73 | |
| parent | Update sphinx config (diff) | |
| download | pygensvm-ab69492525a49e64500da97fc6798b2b436db3f0.tar.gz pygensvm-ab69492525a49e64500da97fc6798b2b436db3f0.zip | |
Documentation improvements
| -rw-r--r-- | gensvm/core.py | 33 | ||||
| -rw-r--r-- | gensvm/gridsearch.py | 59 |
2 files changed, 65 insertions, 27 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index 45d59ad..169a30c 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -104,7 +104,7 @@ class GenSVM(BaseEstimator, ClassifierMixin): errors. It is this flexibility that makes it perform well on diverse datasets. - The :func:`~GenSVM.fit` and :func:`~GenSVM.predict` methods of this class + The :meth:`~GenSVM.fit` and :meth:`~GenSVM.predict` methods of this class use the GenSVM C library for the actual computations. Parameters @@ -123,7 +123,7 @@ class GenSVM(BaseEstimator, ClassifierMixin): 'group' for group size correction weights (equation 4 in the paper). It is also possible to provide an explicit vector of sample weights - through the :func:`~GenSVM.fit` method. If so, it will override the + through the :meth:`~GenSVM.fit` method. If so, it will override the setting provided here. kernel : string, optional (default='linear') @@ -183,7 +183,7 @@ class GenSVM(BaseEstimator, ClassifierMixin): See Also -------- - :class:`.GenSVMGridSearchCV`: + :class:`~.gridsearch.GenSVMGridSearchCV`: Helper class to run an efficient grid search for GenSVM. @@ -257,8 +257,8 @@ class GenSVM(BaseEstimator, ClassifierMixin): def fit(self, X, y, sample_weight=None, seed_V=None): """Fit the GenSVM model on the given data - The model can be fit with or without a seed matrix (``seed_V``). This - can be used to provide warm starts for the algorithm. + The model can be fit with or without a seed matrix (`seed_V`). This can + be used to provide warm starts for the algorithm. Parameters ---------- @@ -280,14 +280,13 @@ class GenSVM(BaseEstimator, ClassifierMixin): <.GenSVM.combined_coef_>` attribute of a different GenSVM model. This is only supported for the linear kernel. - NOTE: the size of the seed_V matrix is ``n_features+1`` by - ``n_classes - 1``. The number of columns of ``seed_V`` is leading - for the number of classes in the model. For example, if ``y`` - contains 3 different classes and ``seed_V`` has 3 columns, we - assume that there are actually 4 classes in the problem but one - class is just represented in this training data. This can be useful - for problems were a certain class has only a few samples. - + NOTE: the size of the seed_V matrix is `n_features+1` by `n_classes + - 1`. The number of columns of `seed_V` is leading for the number + of classes in the model. For example, if `y` contains 3 different + classes and `seed_V` has 3 columns, we assume that there are + actually 4 classes in the problem but one class is just + represented in this training data. This can be useful for + problems were a certain class has only a few samples. Returns ------- @@ -354,7 +353,13 @@ class GenSVM(BaseEstimator, ClassifierMixin): ) ) - self.coef_, self.intercept_, self.n_iter_, self.n_support_, self.SVs_ = _fit_gensvm( + ( + self.coef_, + self.intercept_, + self.n_iter_, + self.n_support_, + self.SVs_, + ) = _fit_gensvm( X, y, n_class, diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index 22125a4..550ae64 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -65,7 +65,7 @@ def _validate_param_grid(param_grid): """Check if the parameter values are valid This basically does the same checks as in the constructor of the - :class:`core.GenSVM` class, but for the entire parameter grid. + :class:`~.core.GenSVM` class, but for the entire parameter grid. """ # the conditions that the parameters must satisfy @@ -169,16 +169,24 @@ def _format_results( score_time = 0 if return_train_score: - train_pred = predictions[cv_idx != test_idx,] - y_train = true_y[cv_idx != test_idx,] + train_pred = predictions[ + cv_idx != test_idx, + ] + y_train = true_y[ + cv_idx != test_idx, + ] train_score, score_t = _wrap_score( train_pred, y_train, scorers, is_multimetric ) score_time += score_t ret.append(train_score) - test_pred = predictions[cv_idx == test_idx,] - y_test = true_y[cv_idx == test_idx,] + test_pred = predictions[ + cv_idx == test_idx, + ] + y_test = true_y[ + cv_idx == test_idx, + ] test_score, score_t = _wrap_score( test_pred, y_test, scorers, is_multimetric ) @@ -232,7 +240,7 @@ def _fit_grid_gensvm( Returns ------- cv_results_ : dict - The cross validation results. See :func:`~GenSVMGridSearchCV.fit`. + The cross validation results. See :meth:`~GenSVMGridSearchCV.fit`. """ @@ -349,7 +357,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): The refitted estimator is made available at the `:attr:best_estimator_ <.GenSVMGridSearchCV.best_estimator_>` attribute and allows the user to - use the :func:`~GenSVMGridSearchCV.predict` method directly on this + use the :meth:`~GenSVMGridSearchCV.predict` method directly on this :class:`.GenSVMGridSearchCV` instance. Also for multiple metric evaluation, the attributes :attr:`best_index_ @@ -623,7 +631,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): and not self.best_params_["kernel"] == "linear" and not "gamma" in self.best_params_ ): - self.best_params_["gamma"] = 1. / X.shape[1] + self.best_params_["gamma"] = 1.0 / X.shape[1] self.best_estimator_ = GenSVM(**self.best_params_) # y_orig because GenSVM fit must know the conversion for predict to # work correctly @@ -711,9 +719,24 @@ def load_grid_tiny(): """ pg = [ - {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -16)], "weights": ["unit"]}, - {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -18)], "weights": ["unit"]}, - {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -18)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -16)], + "weights": ["unit"], + }, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -18)], + "weights": ["unit"], + }, + { + "p": [2.0], + "kappa": [0.5], + "lmd": [pow(2, -18)], + "weights": ["unit"], + }, { "p": [2.0], "kappa": [5.0], @@ -726,7 +749,12 @@ def load_grid_tiny(): "lmd": [pow(2, -18)], "weights": ["unit"], }, - {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -14)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -14)], + "weights": ["unit"], + }, { "p": [2.0], "kappa": [0.5], @@ -739,7 +767,12 @@ def load_grid_tiny(): "lmd": [pow(2, -18)], "weights": ["unit"], }, - {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -16)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [0.5], + "lmd": [pow(2, -16)], + "weights": ["unit"], + }, { "p": [2.0], "kappa": [0.5], |
