aboutsummaryrefslogtreecommitdiff
path: root/gensvm/util.py
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2020-03-06 16:06:17 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2020-03-06 16:06:17 +0000
commitaee92cf879fbd11975cca4b9cfd2cc110b6cb229 (patch)
treea1d14049cc7e2bc34362ab72a5748e37590db65e /gensvm/util.py
parentUpdate Makefile to current workflow (diff)
parentMerge branch 'master' into packaging (diff)
downloadpygensvm-aee92cf879fbd11975cca4b9cfd2cc110b6cb229.tar.gz
pygensvm-aee92cf879fbd11975cca4b9cfd2cc110b6cb229.zip
Merge branch 'packaging'
Diffstat (limited to 'gensvm/util.py')
-rw-r--r--gensvm/util.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/gensvm/util.py b/gensvm/util.py
index 046f3be..40d0eb1 100644
--- a/gensvm/util.py
+++ b/gensvm/util.py
@@ -8,6 +8,7 @@ Utility functions for GenSVM
import numpy as np
+from sklearn.exceptions import NotFittedError
def get_ranks(a):
"""
@@ -37,3 +38,15 @@ def get_ranks(a):
ranks[~np.isnan(orig)] = count[dense - 1] + 1
ranks[np.isnan(orig)] = np.max(ranks) + 1
return list(ranks)
+
+
+def check_is_fitted(estimator, attribute):
+ msg = (
+ "This %(name)s instance is not fitted yet. Call 'fit' "
+ "with appropriate arguments before using this estimator."
+ )
+ if not hasattr(estimator, "fit"):
+ raise TypeError("%s is not an estimator instance" % (estimator))
+
+ if not hasattr(estimator, attribute):
+ raise NotFittedError(msg % {"name": type(estimator).__name__})