aboutsummaryrefslogtreecommitdiff
path: root/gensvm/sklearn_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/sklearn_util.py')
-rw-r--r--gensvm/sklearn_util.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py
index 182f257..eb8ceb6 100644
--- a/gensvm/sklearn_util.py
+++ b/gensvm/sklearn_util.py
@@ -89,7 +89,9 @@ def _skl_format_cv_results(
score_time,
) = zip(*out)
else:
- (test_score_dicts, test_sample_counts, fit_time, score_time) = zip(*out)
+ (test_score_dicts, test_sample_counts, fit_time, score_time) = zip(
+ *out
+ )
# test_score_dicts and train_score dicts are lists of dictionaries and
# we make them into dict of lists
@@ -160,7 +162,9 @@ def _skl_format_cv_results(
)
if return_train_score:
_store(
- "train_%s" % scorer_name, train_scores[scorer_name], splits=True
+ "train_%s" % scorer_name,
+ train_scores[scorer_name],
+ splits=True,
)
return results
@@ -207,7 +211,12 @@ def _skl_check_is_fitted(estimator, method_name, refit):
"attribute" % (type(estimator).__name__, method_name)
)
else:
- check_is_fitted(estimator, "best_estimator_")
+ if not hasattr(estimator, "best_estimator_"):
+ raise NotFittedError(
+ "This %s instance is not fitted yet. Call "
+ "'fit' with appropriate arguments before using this "
+ "estimator." % type(estimator).__name__
+ )
def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_):