diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-01-15 16:57:07 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-01-15 16:57:07 +0000 |
| commit | 54506be3f8e89f0a3f52242863dc3cbb4186bd31 (patch) | |
| tree | 765af9be62004bb9baa8f1d1afc2a4213eb2f800 /gensvm | |
| parent | Update submodule (diff) | |
| download | pygensvm-54506be3f8e89f0a3f52242863dc3cbb4186bd31.tar.gz pygensvm-54506be3f8e89f0a3f52242863dc3cbb4186bd31.zip | |
Update grid search helper to add verbosity
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pxd | 2 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 4 | ||||
| -rw-r--r-- | gensvm/gridsearch.py | 8 |
3 files changed, 11 insertions, 3 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index 6a896aa..e3e28a5 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -131,7 +131,7 @@ cdef extern from "gensvm_helper.c": void gensvm_predict(char *, char *, long, long, long, char *) nogil void gensvm_predict_kernels(char *, char *, char *, long, long, long, long, long, long, int, double, double, double, double, char *) nogil - void gensvm_train_q_helper(GenQueue *, char *, int) nogil + void gensvm_train_q_helper(GenQueue *, char *, int, int) nogil void set_queue(GenQueue *, long, GenTask **) double get_task_duration(GenTask *) double get_task_performance(GenTask *) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index e138620..a321102 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -179,6 +179,7 @@ def grid_wrap( int store_predictions, np.ndarray[np.int_t, ndim=1, mode='c'] cv_idx, int n_folds, + int verbosity, ): """ """ @@ -233,7 +234,8 @@ def grid_wrap( set_queue(queue, n_tasks, tasks) with nogil: - gensvm_train_q_helper(queue, cv_idx.data, store_predictions) + gensvm_train_q_helper(queue, cv_idx.data, store_predictions, + verbosity) cdef np.ndarray[np.int_t, ndim=1, mode='c'] pred diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index bf4b9ce..fbb9168 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -245,7 +245,13 @@ def _fit_grid_gensvm( fold_idx += 1 results_ = wrapper.grid_wrap( - X, y, candidate_params, int(store_predictions), cv_idx, int(n_folds) + X, + y, + candidate_params, + int(store_predictions), + cv_idx, + int(n_folds), + int(verbose), ) cv_results_ = _format_results(results_, cv_idx, y, scorers, iid) |
