aboutsummaryrefslogtreecommitdiff
path: root/gensvm/util.py
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-12-04 13:08:42 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-12-04 13:09:02 +0000
commit0361b805330cbf72263af025a2267b196456f715 (patch)
treea042049071702163d30bce82f9523d85714b1bac /gensvm/util.py
parentRemove is_multimetric flag (diff)
downloadpygensvm-0361b805330cbf72263af025a2267b196456f715.tar.gz
pygensvm-0361b805330cbf72263af025a2267b196456f715.zip
Deal with changing behavior of check_is_fitted
This is going to change in version 0.23, so we might as well inline it now.
Diffstat (limited to 'gensvm/util.py')
-rw-r--r--gensvm/util.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/gensvm/util.py b/gensvm/util.py
index 046f3be..40d0eb1 100644
--- a/gensvm/util.py
+++ b/gensvm/util.py
@@ -8,6 +8,7 @@ Utility functions for GenSVM
import numpy as np
+from sklearn.exceptions import NotFittedError
def get_ranks(a):
"""
@@ -37,3 +38,15 @@ def get_ranks(a):
ranks[~np.isnan(orig)] = count[dense - 1] + 1
ranks[np.isnan(orig)] = np.max(ranks) + 1
return list(ranks)
+
+
+def check_is_fitted(estimator, attribute):
+ msg = (
+ "This %(name)s instance is not fitted yet. Call 'fit' "
+ "with appropriate arguments before using this estimator."
+ )
+ if not hasattr(estimator, "fit"):
+ raise TypeError("%s is not an estimator instance" % (estimator))
+
+ if not hasattr(estimator, attribute):
+ raise NotFittedError(msg % {"name": type(estimator).__name__})