From 0ac62ea5e6d4610c7c601381a5132caa89b08a00 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Fri, 6 Mar 2020 22:56:52 +0000 Subject: Move some stuff around for sklearn This will need to be fixed later, the _MultimetricScorer needs to be removed. --- gensvm/sklearn_util.py | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py index eb8ceb6..d4566f4 100644 --- a/gensvm/sklearn_util.py +++ b/gensvm/sklearn_util.py @@ -14,11 +14,15 @@ required. """ +import numbers import numpy as np from collections import defaultdict +from contextlib import suppress from functools import partial +from sklearn.metrics._scorer import _MultimetricScorer + from .core import GenSVM from .util import get_ranks @@ -62,11 +66,9 @@ DAMAGE. """ from sklearn.exceptions import NotFittedError -from sklearn.externals import six from sklearn.metrics.scorer import _check_multimetric_scoring from sklearn.model_selection._validation import _aggregate_score_dicts from sklearn.utils.fixes import MaskedArray -from sklearn.utils.validation import check_is_fitted def _skl_format_cv_results( @@ -177,10 +179,7 @@ def _skl_check_scorers(scoring, refit): ) if multimetric_: if refit is not False and ( - not isinstance(refit, six.string_types) - or - # This will work for both dict / list (tuple) - refit not in scorers + not isinstance(refit, str) or refit not in scorers ): raise ValueError( "For multi-metric scoring, the parameter " @@ -188,7 +187,7 @@ def _skl_check_scorers(scoring, refit): "to refit an estimator with the best " "parameter setting on the whole data and " "make the best_* attributes " - "available for that metric. If this is not " + "available for that metric. If this is not " "needed, refit should be set to False " "explicitly. %r was passed."
% refit ) @@ -247,3 +246,38 @@ def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_): ) score = scorer_[refit] if multimetric_ else scorer_ return score(best_estimator_, X, y) + + +def _skl_score(estimator, X_test, y_test, scorer): + """Compute the score(s) of an estimator on a given test set. + Will return a dict of floats if `scorer` is a dict, otherwise a single + float is returned. + """ + if isinstance(scorer, dict): + # will cache method calls if needed. scorer() returns a dict + scorer = _MultimetricScorer(**scorer) + if y_test is None: + scores = scorer(estimator, X_test) + else: + scores = scorer(estimator, X_test, y_test) + + error_msg = ( + "scoring must return a number, got %s (%s) " "instead. (scorer=%s)" + ) + if isinstance(scores, dict): + for name, score in scores.items(): + if hasattr(score, "item"): + with suppress(ValueError): + # e.g. unwrap memmapped scalars + score = score.item() + if not isinstance(score, numbers.Number): + raise ValueError(error_msg % (score, type(score), name)) + scores[name] = score + else: # scalar + if hasattr(scores, "item"): + with suppress(ValueError): + # e.g. unwrap memmapped scalars + scores = scores.item() + if not isinstance(scores, numbers.Number): + raise ValueError(error_msg % (scores, type(scores), scorer)) + return scores -- cgit v1.2.3