diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-12-04 13:08:42 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-12-04 13:09:02 +0000 |
| commit | 0361b805330cbf72263af025a2267b196456f715 (patch) | |
| tree | a042049071702163d30bce82f9523d85714b1bac /gensvm | |
| parent | Remove is_multimetric flag (diff) | |
| download | pygensvm-0361b805330cbf72263af025a2267b196456f715.tar.gz pygensvm-0361b805330cbf72263af025a2267b196456f715.zip | |
Deal with changing behavior of check_is_fitted
This is going to change in version 0.23, so we
might as well inline it now.
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/core.py | 2 | ||||
| -rw-r--r-- | gensvm/sklearn_util.py | 7 | ||||
| -rw-r--r-- | gensvm/util.py | 13 |
3 files changed, 20 insertions, 2 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index bfd5d9a..45d59ad 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -16,9 +16,9 @@ from sklearn.exceptions import ConvergenceWarning, FitFailedWarning from sklearn.preprocessing import LabelEncoder from sklearn.utils import check_array, check_X_y, check_random_state from sklearn.utils.multiclass import type_of_target -from sklearn.utils.validation import check_is_fitted from .cython_wrapper import wrapper +from .util import check_is_fitted def _fit_gensvm( diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py index 182f257..e23921b 100644 --- a/gensvm/sklearn_util.py +++ b/gensvm/sklearn_util.py @@ -207,7 +207,12 @@ def _skl_check_is_fitted(estimator, method_name, refit): "attribute" % (type(estimator).__name__, method_name) ) else: - check_is_fitted(estimator, "best_estimator_") + if not hasattr(estimator, "best_estimator_"): + raise NotFittedError( + "This %s instance is not fitted yet. Call " + "'fit' with appropriate arguments before using this " + "estimator." % type(estimator).__name__ + ) def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_): diff --git a/gensvm/util.py b/gensvm/util.py index 046f3be..40d0eb1 100644 --- a/gensvm/util.py +++ b/gensvm/util.py @@ -8,6 +8,7 @@ Utility functions for GenSVM import numpy as np +from sklearn.exceptions import NotFittedError def get_ranks(a): """ @@ -37,3 +38,15 @@ def get_ranks(a): ranks[~np.isnan(orig)] = count[dense - 1] + 1 ranks[np.isnan(orig)] = np.max(ranks) + 1 return list(ranks) + + +def check_is_fitted(estimator, attribute): + msg = ( + "This %(name)s instance is not fitted yet. Call 'fit' " + "with appropriate arguments before using this estimator." + ) + if not hasattr(estimator, "fit"): + raise TypeError("%s is not an estimator instance" % (estimator)) + + if not hasattr(estimator, attribute): + raise NotFittedError(msg % {"name": type(estimator).__name__}) |
