aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2020-03-06 22:56:52 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2020-03-06 22:56:52 +0000
commit0ac62ea5e6d4610c7c601381a5132caa89b08a00 (patch)
tree82499281f665706fdc3f65dd51ae30b0668e910b
parentRemove unnecessary lines from manifest (diff)
downloadpygensvm-0ac62ea5e6d4610c7c601381a5132caa89b08a00.tar.gz
pygensvm-0ac62ea5e6d4610c7c601381a5132caa89b08a00.zip
Move some stuff around for sklearn
This will need to be fixed later, the _MultimetricScorer needs to be removed.
-rw-r--r--gensvm/sklearn_util.py48
1 files changed, 41 insertions, 7 deletions
diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py
index eb8ceb6..d4566f4 100644
--- a/gensvm/sklearn_util.py
+++ b/gensvm/sklearn_util.py
@@ -14,11 +14,15 @@ required.
"""
+import numbers
import numpy as np
from collections import defaultdict
+from contextlib import suppress
from functools import partial
+from sklearn.metrics._scorer import _MultimetricScorer
+
from .core import GenSVM
from .util import get_ranks
@@ -62,11 +66,9 @@ DAMAGE.
"""
from sklearn.exceptions import NotFittedError
-from sklearn.externals import six
from sklearn.metrics.scorer import _check_multimetric_scoring
from sklearn.model_selection._validation import _aggregate_score_dicts
from sklearn.utils.fixes import MaskedArray
-from sklearn.utils.validation import check_is_fitted
def _skl_format_cv_results(
@@ -177,10 +179,7 @@ def _skl_check_scorers(scoring, refit):
)
if multimetric_:
if refit is not False and (
- not isinstance(refit, six.string_types)
- or
- # This will work for both dict / list (tuple)
- refit not in scorers
+ not isinstance(refit, str) or refit not in scorers
):
raise ValueError(
"For multi-metric scoring, the parameter "
@@ -188,7 +187,7 @@ def _skl_check_scorers(scoring, refit):
"to refit an estimator with the best "
"parameter setting on the whole data and "
"make the best_* attributes "
- "available for that metric. If this is not "
+ "available for that metric. kjIf this is not "
"needed, refit should be set to False "
"explicitly. %r was passed." % refit
)
@@ -247,3 +246,38 @@ def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_):
)
score = scorer_[refit] if multimetric_ else scorer_
return score(best_estimator_, X, y)
+
+
+def _skl_score(estimator, X_test, y_test, scorer):
+ """Compute the score(s) of an estimator on a given test set.
+ Will return a dict of floats if `scorer` is a dict, otherwise a single
+ float is returned.
+ """
+ if isinstance(scorer, dict):
+ # will cache method calls if needed. scorer() returns a dict
+ scorer = _MultimetricScorer(**scorer)
+ if y_test is None:
+ scores = scorer(estimator, X_test)
+ else:
+ scores = scorer(estimator, X_test, y_test)
+
+ error_msg = (
+ "scoring must return a number, got %s (%s) " "instead. (scorer=%s)"
+ )
+ if isinstance(scores, dict):
+ for name, score in scores.items():
+ if hasattr(score, "item"):
+ with suppress(ValueError):
+ # e.g. unwrap memmapped scalars
+ score = score.item()
+ if not isinstance(score, numbers.Number):
+ raise ValueError(error_msg % (score, type(score), name))
+ scores[name] = score
+ else: # scalar
+ if hasattr(scores, "item"):
+ with suppress(ValueError):
+ # e.g. unwrap memmapped scalars
+ scores = scores.item()
+ if not isinstance(scores, numbers.Number):
+ raise ValueError(error_msg % (scores, type(scores), scorer))
+ return scores