diff options
Diffstat (limited to 'gensvm/gridsearch.py')
| -rw-r--r-- | gensvm/gridsearch.py | 59 |
1 files changed, 46 insertions, 13 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index 22125a4..550ae64 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -65,7 +65,7 @@ def _validate_param_grid(param_grid): """Check if the parameter values are valid This basically does the same checks as in the constructor of the - :class:`core.GenSVM` class, but for the entire parameter grid. + :class:`~.core.GenSVM` class, but for the entire parameter grid. """ # the conditions that the parameters must satisfy @@ -169,16 +169,24 @@ def _format_results( score_time = 0 if return_train_score: - train_pred = predictions[cv_idx != test_idx,] - y_train = true_y[cv_idx != test_idx,] + train_pred = predictions[ + cv_idx != test_idx, + ] + y_train = true_y[ + cv_idx != test_idx, + ] train_score, score_t = _wrap_score( train_pred, y_train, scorers, is_multimetric ) score_time += score_t ret.append(train_score) - test_pred = predictions[cv_idx == test_idx,] - y_test = true_y[cv_idx == test_idx,] + test_pred = predictions[ + cv_idx == test_idx, + ] + y_test = true_y[ + cv_idx == test_idx, + ] test_score, score_t = _wrap_score( test_pred, y_test, scorers, is_multimetric ) @@ -232,7 +240,7 @@ def _fit_grid_gensvm( Returns ------- cv_results_ : dict - The cross validation results. See :func:`~GenSVMGridSearchCV.fit`. + The cross validation results. See :meth:`~GenSVMGridSearchCV.fit`. """ @@ -349,7 +357,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): The refitted estimator is made available at the `:attr:best_estimator_ <.GenSVMGridSearchCV.best_estimator_>` attribute and allows the user to - use the :func:`~GenSVMGridSearchCV.predict` method directly on this + use the :meth:`~GenSVMGridSearchCV.predict` method directly on this :class:`.GenSVMGridSearchCV` instance. Also for multiple metric evaluation, the attributes :attr:`best_index_ @@ -623,7 +631,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): and not self.best_params_["kernel"] == "linear" and not "gamma" in self.best_params_ ): - self.best_params_["gamma"] = 1. / X.shape[1] + self.best_params_["gamma"] = 1.0 / X.shape[1] self.best_estimator_ = GenSVM(**self.best_params_) # y_orig because GenSVM fit must know the conversion for predict to # work correctly @@ -711,9 +719,24 @@ def load_grid_tiny(): """ pg = [ - {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -16)], "weights": ["unit"]}, - {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -18)], "weights": ["unit"]}, - {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -18)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -16)], + "weights": ["unit"], + }, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -18)], + "weights": ["unit"], + }, + { + "p": [2.0], + "kappa": [0.5], + "lmd": [pow(2, -18)], + "weights": ["unit"], + }, { "p": [2.0], "kappa": [5.0], @@ -726,7 +749,12 @@ def load_grid_tiny(): "lmd": [pow(2, -18)], "weights": ["unit"], }, - {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -14)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -14)], + "weights": ["unit"], + }, { "p": [2.0], "kappa": [0.5], @@ -739,7 +767,12 @@ def load_grid_tiny(): "lmd": [pow(2, -18)], "weights": ["unit"], }, - {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -16)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [0.5], + "lmd": [pow(2, -16)], + "weights": ["unit"], + }, { "p": [2.0], "kappa": [0.5], |
