diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 16:06:17 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 16:06:17 +0000 |
| commit | aee92cf879fbd11975cca4b9cfd2cc110b6cb229 (patch) | |
| tree | a1d14049cc7e2bc34362ab72a5748e37590db65e /gensvm/util.py | |
| parent | Update Makefile to current workflow (diff) | |
| parent | Merge branch 'master' into packaging (diff) | |
| download | pygensvm-aee92cf879fbd11975cca4b9cfd2cc110b6cb229.tar.gz pygensvm-aee92cf879fbd11975cca4b9cfd2cc110b6cb229.zip | |
Merge branch 'packaging'
Diffstat (limited to 'gensvm/util.py')
| -rw-r--r-- | gensvm/util.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/gensvm/util.py b/gensvm/util.py index 046f3be..40d0eb1 100644 --- a/gensvm/util.py +++ b/gensvm/util.py @@ -8,6 +8,7 @@ Utility functions for GenSVM import numpy as np +from sklearn.exceptions import NotFittedError def get_ranks(a): """ @@ -37,3 +38,15 @@ def get_ranks(a): ranks[~np.isnan(orig)] = count[dense - 1] + 1 ranks[np.isnan(orig)] = np.max(ranks) + 1 return list(ranks) + + +def check_is_fitted(estimator, attribute): + msg = ( + "This %(name)s instance is not fitted yet. Call 'fit' " + "with appropriate arguments before using this estimator." + ) + if not hasattr(estimator, "fit"): + raise TypeError("%s is not an estimator instance" % (estimator)) + + if not hasattr(estimator, attribute): + raise NotFittedError(msg % {"name": type(estimator).__name__}) |
