aboutsummaryrefslogtreecommitdiff
path: root/gensvm
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 12:22:33 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 12:23:49 -0500
commit33b99f5e38acbc0e28832b997e2f15f602bdd9a6 (patch)
treeec00ff8c462e58c3b73bed1f9b2ccca0fc7a79c8 /gensvm
parentUpdate grid search helper to add verbosity (diff)
downloadpygensvm-33b99f5e38acbc0e28832b997e2f15f602bdd9a6.tar.gz
pygensvm-33b99f5e38acbc0e28832b997e2f15f602bdd9a6.zip
Extract durations array
Diffstat (limited to 'gensvm')
-rw-r--r--gensvm/cython_wrapper/wrapper.pxd1
-rw-r--r--gensvm/cython_wrapper/wrapper.pyx10
2 files changed, 10 insertions, 1 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd
index e3e28a5..5b097a6 100644
--- a/gensvm/cython_wrapper/wrapper.pxd
+++ b/gensvm/cython_wrapper/wrapper.pxd
@@ -136,3 +136,4 @@ cdef extern from "gensvm_helper.c":
double get_task_duration(GenTask *)
double get_task_performance(GenTask *)
void copy_task_predictions(GenTask *, char *, long)
+ void copy_task_durations(GenTask *, char *, int)
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index a321102..5c58fee 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -238,20 +238,28 @@ def grid_wrap(
verbosity)
cdef np.ndarray[np.int_t, ndim=1, mode='c'] pred
+ cdef np.ndarray[np.double_t, ndim=1, mode='c'] dur
results = dict()
results['params'] = []
results['duration'] = []
results['scores'] = []
+ # predictions: for each task, an array of size n_obs with class
+ # predictions (-1 if missing)
results['predictions'] = []
+ # durations: for each task, an array of size n_folds with duration for
+ # each fold (nan if missing)
+ results['durations'] = []
for ID in range(n_tasks):
results['params'].append(candidate_params[ID])
- results['duration'].append(get_task_duration(tasks[ID]))
results['scores'].append(get_task_performance(tasks[ID]))
if store_predictions:
pred = np.zeros((n_obs, ), dtype=np.int)
copy_task_predictions(tasks[ID], pred.data, n_obs)
results['predictions'].append(pred.copy())
+ dur = np.zeros((n_folds, ), dtype=np.double)
+ copy_task_durations(tasks[ID], dur.data, n_folds)
+ results['durations'].append(dur.copy())
gensvm_free_queue(queue)
free_data(data)