aboutsummaryrefslogtreecommitdiff
path: root/gensvm/gridsearch.py
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/gridsearch.py')
-rw-r--r--gensvm/gridsearch.py59
1 files changed, 46 insertions, 13 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py
index 22125a4..550ae64 100644
--- a/gensvm/gridsearch.py
+++ b/gensvm/gridsearch.py
@@ -65,7 +65,7 @@ def _validate_param_grid(param_grid):
"""Check if the parameter values are valid
This basically does the same checks as in the constructor of the
- :class:`core.GenSVM` class, but for the entire parameter grid.
+ :class:`~.core.GenSVM` class, but for the entire parameter grid.
"""
# the conditions that the parameters must satisfy
@@ -169,16 +169,24 @@ def _format_results(
score_time = 0
if return_train_score:
- train_pred = predictions[cv_idx != test_idx,]
- y_train = true_y[cv_idx != test_idx,]
+ train_pred = predictions[
+ cv_idx != test_idx,
+ ]
+ y_train = true_y[
+ cv_idx != test_idx,
+ ]
train_score, score_t = _wrap_score(
train_pred, y_train, scorers, is_multimetric
)
score_time += score_t
ret.append(train_score)
- test_pred = predictions[cv_idx == test_idx,]
- y_test = true_y[cv_idx == test_idx,]
+ test_pred = predictions[
+ cv_idx == test_idx,
+ ]
+ y_test = true_y[
+ cv_idx == test_idx,
+ ]
test_score, score_t = _wrap_score(
test_pred, y_test, scorers, is_multimetric
)
@@ -232,7 +240,7 @@ def _fit_grid_gensvm(
Returns
-------
cv_results_ : dict
- The cross validation results. See :func:`~GenSVMGridSearchCV.fit`.
+ The cross validation results. See :meth:`~GenSVMGridSearchCV.fit`.
"""
@@ -349,7 +357,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
The refitted estimator is made available at the `:attr:best_estimator_
<.GenSVMGridSearchCV.best_estimator_>` attribute and allows the user to
- use the :func:`~GenSVMGridSearchCV.predict` method directly on this
+ use the :meth:`~GenSVMGridSearchCV.predict` method directly on this
:class:`.GenSVMGridSearchCV` instance.
Also for multiple metric evaluation, the attributes :attr:`best_index_
@@ -623,7 +631,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin):
and not self.best_params_["kernel"] == "linear"
and not "gamma" in self.best_params_
):
- self.best_params_["gamma"] = 1. / X.shape[1]
+ self.best_params_["gamma"] = 1.0 / X.shape[1]
self.best_estimator_ = GenSVM(**self.best_params_)
# y_orig because GenSVM fit must know the conversion for predict to
# work correctly
@@ -711,9 +719,24 @@ def load_grid_tiny():
"""
pg = [
- {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -16)], "weights": ["unit"]},
- {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -18)], "weights": ["unit"]},
- {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -18)], "weights": ["unit"]},
+ {
+ "p": [2.0],
+ "kappa": [5.0],
+ "lmd": [pow(2, -16)],
+ "weights": ["unit"],
+ },
+ {
+ "p": [2.0],
+ "kappa": [5.0],
+ "lmd": [pow(2, -18)],
+ "weights": ["unit"],
+ },
+ {
+ "p": [2.0],
+ "kappa": [0.5],
+ "lmd": [pow(2, -18)],
+ "weights": ["unit"],
+ },
{
"p": [2.0],
"kappa": [5.0],
@@ -726,7 +749,12 @@ def load_grid_tiny():
"lmd": [pow(2, -18)],
"weights": ["unit"],
},
- {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -14)], "weights": ["unit"]},
+ {
+ "p": [2.0],
+ "kappa": [5.0],
+ "lmd": [pow(2, -14)],
+ "weights": ["unit"],
+ },
{
"p": [2.0],
"kappa": [0.5],
@@ -739,7 +767,12 @@ def load_grid_tiny():
"lmd": [pow(2, -18)],
"weights": ["unit"],
},
- {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -16)], "weights": ["unit"]},
+ {
+ "p": [2.0],
+ "kappa": [0.5],
+ "lmd": [pow(2, -16)],
+ "weights": ["unit"],
+ },
{
"p": [2.0],
"kappa": [0.5],