diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 16:06:17 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-03-06 16:06:17 +0000 |
| commit | aee92cf879fbd11975cca4b9cfd2cc110b6cb229 (patch) | |
| tree | a1d14049cc7e2bc34362ab72a5748e37590db65e /gensvm | |
| parent | Update Makefile to current workflow (diff) | |
| parent | Merge branch 'master' into packaging (diff) | |
| download | pygensvm-aee92cf879fbd11975cca4b9cfd2cc110b6cb229.tar.gz pygensvm-aee92cf879fbd11975cca4b9cfd2cc110b6cb229.zip | |
Merge branch 'packaging'
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/core.py | 2 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 25 | ||||
| -rw-r--r-- | gensvm/gridsearch.py | 2 | ||||
| -rw-r--r-- | gensvm/sklearn_util.py | 15 | ||||
| -rw-r--r-- | gensvm/util.py | 13 |
5 files changed, 37 insertions, 20 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index bfd5d9a..45d59ad 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -16,9 +16,9 @@ from sklearn.exceptions import ConvergenceWarning, FitFailedWarning from sklearn.preprocessing import LabelEncoder from sklearn.utils import check_array, check_X_y, check_random_state from sklearn.utils.multiclass import type_of_target -from sklearn.utils.validation import check_is_fitted from .cython_wrapper import wrapper +from .util import check_is_fitted def _fit_gensvm( diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 009e70b..3a85e92 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -88,8 +88,7 @@ def train_wrap( raise ValueError(error_repl) # Do the actual training - with nogil: - gensvm_train(model, data, seed_model) + gensvm_train(model, data, seed_model) # update the number of variables (this may have changed due to kernel) n_var = get_m(model) @@ -134,12 +133,11 @@ def predict_wrap( # output vector cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions - predictions = np.empty((n_test_obs, ), dtype=np.int) + predictions = np.empty((n_test_obs, ), dtype=np.int_) # do the prediction - with nogil: - gensvm_predict(X.data, V.data, n_test_obs, n_var, n_class, - predictions.data) + gensvm_predict(X.data, V.data, n_test_obs, n_var, n_class, + predictions.data) return predictions @@ -174,10 +172,9 @@ def predict_kernels_wrap( cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions predictions = np.empty((n_obs_test, ), dtype=np.int) - with nogil: - gensvm_predict_kernels(Xtest.data, Xtrain.data, V.data, V_rows, - V_cols, n_obs_train, n_obs_test, n_var, n_class, kernel_idx, - gamma, coef, degree, kernel_eigen_cutoff, predictions.data) + gensvm_predict_kernels(Xtest.data, Xtrain.data, V.data, V_rows, V_cols, + n_obs_train, n_obs_test, n_var, n_class, kernel_idx, gamma, coef, + degree, kernel_eigen_cutoff, predictions.data) return predictions @@ -243,9 +240,7 @@ def grid_wrap( set_queue(queue, n_tasks, tasks) - with nogil: - gensvm_train_q_helper(queue, cv_idx.data, store_predictions, - verbosity) + gensvm_train_q_helper(queue, cv_idx.data, store_predictions, verbosity) cdef np.ndarray[np.int_t, ndim=1, mode='c'] pred cdef np.ndarray[np.double_t, ndim=1, mode='c'] dur @@ -264,12 +259,12 @@ def grid_wrap( results['params'].append(candidate_params[ID]) results['scores'].append(get_task_performance(tasks[ID])) if store_predictions: - pred = np.zeros((n_obs, ), dtype=np.int) + pred = np.zeros((n_obs, ), dtype=np.int_) copy_task_predictions(tasks[ID], pred.data, n_obs) results['predictions'].append(pred.copy()) dur = np.zeros((n_folds, ), dtype=np.double) copy_task_durations(tasks[ID], dur.data, n_folds) - results['durations'].append(dur.copy()) + results['durations'].append(dur) gensvm_free_queue(queue) free_data(data) diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index b27a347..22125a4 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -116,7 +116,7 @@ def _wrap_score(y_pred, y_true, scorers, is_multimetric): results["score"] = np.nan else: estimator = _MockEstimator(y_pred) - results = _score(estimator, None, y_true, scorers, is_multimetric) + results = _score(estimator, None, y_true, scorers) score_time = time.time() - start_time return results, score_time diff --git a/gensvm/sklearn_util.py b/gensvm/sklearn_util.py index 182f257..eb8ceb6 100644 --- a/gensvm/sklearn_util.py +++ b/gensvm/sklearn_util.py @@ -89,7 +89,9 @@ def _skl_format_cv_results( score_time, ) = zip(*out) else: - (test_score_dicts, test_sample_counts, fit_time, score_time) = zip(*out) + (test_score_dicts, test_sample_counts, fit_time, score_time) = zip( + *out + ) # test_score_dicts and train_score dicts are lists of dictionaries and # we make them into dict of lists @@ -160,7 +162,9 @@ def _skl_format_cv_results( ) if return_train_score: _store( - "train_%s" % scorer_name, train_scores[scorer_name], splits=True + "train_%s" % scorer_name, + train_scores[scorer_name], + splits=True, ) return results @@ -207,7 +211,12 @@ def _skl_check_is_fitted(estimator, method_name, refit): "attribute" % (type(estimator).__name__, method_name) ) else: - check_is_fitted(estimator, "best_estimator_") + if not hasattr(estimator, "best_estimator_"): + raise NotFittedError( + "This %s instance is not fitted yet. Call " + "'fit' with appropriate arguments before using this " + "estimator." % type(estimator).__name__ + ) def _skl_grid_score(X, y, scorer_, best_estimator_, refit, multimetric_): diff --git a/gensvm/util.py b/gensvm/util.py index 046f3be..40d0eb1 100644 --- a/gensvm/util.py +++ b/gensvm/util.py @@ -8,6 +8,7 @@ Utility functions for GenSVM import numpy as np +from sklearn.exceptions import NotFittedError def get_ranks(a): """ @@ -37,3 +38,15 @@ def get_ranks(a): ranks[~np.isnan(orig)] = count[dense - 1] + 1 ranks[np.isnan(orig)] = np.max(ranks) + 1 return list(ranks) + + +def check_is_fitted(estimator, attribute): + msg = ( + "This %(name)s instance is not fitted yet. Call 'fit' " + "with appropriate arguments before using this estimator." + ) + if not hasattr(estimator, "fit"): + raise TypeError("%s is not an estimator instance" % (estimator)) + + if not hasattr(estimator, attribute): + raise NotFittedError(msg % {"name": type(estimator).__name__}) |
