diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 22:56:52 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 22:56:52 +0000 |
| commit | 0ac62ea5e6d4610c7c601381a5132caa89b08a00 (patch) | |
| tree | 82499281f665706fdc3f65dd51ae30b0668e910b | |
| parent | Remove unnecessary lines from manifest (diff) | |
| download | pygensvm-0ac62ea5e6d4610c7c601381a5132caa89b08a00.tar.gz pygensvm-0ac62ea5e6d4610c7c601381a5132caa89b08a00.zip | |
Move some stuff around for sklearn
This will need to be fixed later,
the _MultimetricScorer needs to be
removed.
| -rw-r--r-- | gensvm/sklearn_util.py | 48 |
1 files changed, 41 insertions, 7 deletions
diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py index eb8ceb6..d4566f4 100644 --- a/gensvm/sklearn_util.py +++ b/gensvm/sklearn_util.py @@ -14,11 +14,15 @@ required. """ +import numbers import numpy as np from collections import defaultdict +from contextlib import suppress from functools import partial +from sklearn.metrics._scorer import _MultimetricScorer + from .core import GenSVM from .util import get_ranks @@ -62,11 +66,9 @@ DAMAGE. """ from sklearn.exceptions import NotFittedError -from sklearn.externals import six from sklearn.metrics.scorer import _check_multimetric_scoring from sklearn.model_selection._validation import _aggregate_score_dicts from sklearn.utils.fixes import MaskedArray -from sklearn.utils.validation import check_is_fitted def _skl_format_cv_results( @@ -177,10 +179,7 @@ def _skl_check_scorers(scoring, refit): ) if multimetric_: if refit is not False and ( - not isinstance(refit, six.string_types) - or - # This will work for both dict / list (tuple) - refit not in scorers + not isinstance(refit, str) or refit not in scorers ): raise ValueError( "For multi-metric scoring, the parameter " @@ -188,7 +187,7 @@ def _skl_check_scorers(scoring, refit): "to refit an estimator with the best " "parameter setting on the whole data and " "make the best_* attributes " - "available for that metric. If this is not " + "available for that metric. kjIf this is not " "needed, refit should be set to False " "explicitly. %r was passed." % refit ) @@ -247,3 +246,38 @@ def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_): ) score = scorer_[refit] if multimetric_ else scorer_ return score(best_estimator_, X, y) + + +def _skl_score(estimator, X_test, y_test, scorer): + """Compute the score(s) of an estimator on a given test set. + Will return a dict of floats if `scorer` is a dict, otherwise a single + float is returned. + """ + if isinstance(scorer, dict): + # will cache method calls if needed. scorer() returns a dict + scorer = _MultimetricScorer(**scorer) + if y_test is None: + scores = scorer(estimator, X_test) + else: + scores = scorer(estimator, X_test, y_test) + + error_msg = ( + "scoring must return a number, got %s (%s) " "instead. (scorer=%s)" + ) + if isinstance(scores, dict): + for name, score in scores.items(): + if hasattr(score, "item"): + with suppress(ValueError): + # e.g. unwrap memmapped scalars + score = score.item() + if not isinstance(score, numbers.Number): + raise ValueError(error_msg % (score, type(score), name)) + scores[name] = score + else: # scalar + if hasattr(scores, "item"): + with suppress(ValueError): + # e.g. unwrap memmapped scalars + scores = scores.item() + if not isinstance(scores, numbers.Number): + raise ValueError(error_msg % (scores, type(scores), scorer)) + return scores |
