diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 12:22:33 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 12:23:49 -0500 |
| commit | 33b99f5e38acbc0e28832b997e2f15f602bdd9a6 (patch) | |
| tree | ec00ff8c462e58c3b73bed1f9b2ccca0fc7a79c8 /gensvm | |
| parent | Update grid search helper to add verbosity (diff) | |
| download | pygensvm-33b99f5e38acbc0e28832b997e2f15f602bdd9a6.tar.gz pygensvm-33b99f5e38acbc0e28832b997e2f15f602bdd9a6.zip | |
Extract durations array
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pxd | 1 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 10 |
2 files changed, 10 insertions, 1 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index e3e28a5..5b097a6 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -136,3 +136,4 @@ cdef extern from "gensvm_helper.c": double get_task_duration(GenTask *) double get_task_performance(GenTask *) void copy_task_predictions(GenTask *, char *, long) + void copy_task_durations(GenTask *, char *, int) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index a321102..5c58fee 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -238,20 +238,28 @@ def grid_wrap( verbosity) cdef np.ndarray[np.int_t, ndim=1, mode='c'] pred + cdef np.ndarray[np.double_t, ndim=1, mode='c'] dur results = dict() results['params'] = [] results['duration'] = [] results['scores'] = [] + # predictions: for each task, an array of size n_obs with class + # predictions (-1 if missing) results['predictions'] = [] + # durations: for each task, an array of size n_folds with duration for + # each fold (nan if missing) + results['durations'] = [] for ID in range(n_tasks): results['params'].append(candidate_params[ID]) - results['duration'].append(get_task_duration(tasks[ID])) results['scores'].append(get_task_performance(tasks[ID])) if store_predictions: pred = np.zeros((n_obs, ), dtype=np.int) copy_task_predictions(tasks[ID], pred.data, n_obs) results['predictions'].append(pred.copy()) + dur = np.zeros((n_folds, ), dtype=np.double) + copy_task_durations(tasks[ID], dur.data, n_folds) + results['durations'].append(dur.copy()) gensvm_free_queue(queue) free_data(data) |
