aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-12-04 13:08:42 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-12-04 13:09:02 +0000
commit0361b805330cbf72263af025a2267b196456f715 (patch)
treea042049071702163d30bce82f9523d85714b1bac
parentRemove is_multimetric flag (diff)
downloadpygensvm-0361b805330cbf72263af025a2267b196456f715.tar.gz
pygensvm-0361b805330cbf72263af025a2267b196456f715.zip
Deal with changing behavior of check_is_fitted
This is going to change in version 0.23, so we might as well inline it now.
-rw-r--r--gensvm/core.py2
-rw-r--r--gensvm/sklearn_util.py7
-rw-r--r--gensvm/util.py13
3 files changed, 20 insertions, 2 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index bfd5d9a..45d59ad 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -16,9 +16,9 @@ from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_array, check_X_y, check_random_state
from sklearn.utils.multiclass import type_of_target
-from sklearn.utils.validation import check_is_fitted
from .cython_wrapper import wrapper
+from .util import check_is_fitted
def _fit_gensvm(
diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py
index 182f257..e23921b 100644
--- a/gensvm/sklearn_util.py
+++ b/gensvm/sklearn_util.py
@@ -207,7 +207,12 @@ def _skl_check_is_fitted(estimator, method_name, refit):
"attribute" % (type(estimator).__name__, method_name)
)
else:
- check_is_fitted(estimator, "best_estimator_")
+ if not hasattr(estimator, "best_estimator_"):
+ raise NotFittedError(
+ "This %s instance is not fitted yet. Call "
+ "'fit' with appropriate arguments before using this "
+ "estimator." % type(estimator).__name__
+ )
def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_):
diff --git a/gensvm/util.py b/gensvm/util.py
index 046f3be..40d0eb1 100644
--- a/gensvm/util.py
+++ b/gensvm/util.py
@@ -8,6 +8,7 @@ Utility functions for GenSVM
import numpy as np
+from sklearn.exceptions import NotFittedError
def get_ranks(a):
"""
@@ -37,3 +38,15 @@ def get_ranks(a):
ranks[~np.isnan(orig)] = count[dense - 1] + 1
ranks[np.isnan(orig)] = np.max(ranks) + 1
return list(ranks)
+
+
+def check_is_fitted(estimator, attribute):
+ msg = (
+ "This %(name)s instance is not fitted yet. Call 'fit' "
+ "with appropriate arguments before using this estimator."
+ )
+ if not hasattr(estimator, "fit"):
+ raise TypeError("%s is not an estimator instance" % (estimator))
+
+ if not hasattr(estimator, attribute):
+ raise NotFittedError(msg % {"name": type(estimator).__name__})