diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2015-01-30 16:22:52 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2015-01-30 16:22:52 +0100 |
| commit | df9c3ca0b62f1a20071bee3a55d24d673c5d11e0 (patch) | |
| tree | d3a2d6be5dfe6e2a4e248ad04dfdbb40852c8f7a /src/gensvm_train_dataset.c | |
| parent | update documentation gensvm structs (diff) | |
| download | gensvm-df9c3ca0b62f1a20071bee3a55d24d673c5d11e0.tar.gz gensvm-df9c3ca0b62f1a20071bee3a55d24d673c5d11e0.zip | |
first working version of new kernel GenSVM
Diffstat (limited to 'src/gensvm_train_dataset.c')
| -rw-r--r-- | src/gensvm_train_dataset.c | 155 |
1 files changed, 154 insertions, 1 deletions
diff --git a/src/gensvm_train_dataset.c b/src/gensvm_train_dataset.c index 3034bb4..eee4bf9 100644 --- a/src/gensvm_train_dataset.c +++ b/src/gensvm_train_dataset.c @@ -435,6 +435,12 @@ void consistency_repeats(struct Queue *q, long repeats, TrainType traintype) * cross validation */ double cross_validation(struct GenModel *model, struct GenData *data, + long folds) +{ + return 0.0; +} +/* +double cross_validation(struct GenModel *model, struct GenData *data, long folds) { FILE *fid; @@ -487,7 +493,7 @@ double cross_validation(struct GenModel *model, struct GenData *data, return total_perf; } - +*/ /** * @brief Run the grid search for a cross validation dataset * @@ -542,6 +548,147 @@ void start_training_cv(struct Queue *q) gensvm_free_model(model); } + +bool kernel_changed(struct Task *newtask, struct Task *oldtask) +{ + if (oldtask == NULL) + return true; + int i; + if (newtask->kerneltype != oldtask->kerneltype) { + return true; + } else if (newtask->kerneltype == K_POLY) { + for (i=0; i<3; i++) + if (newtask->kernelparam[i] != oldtask->kernelparam[i]) + return true; + return false; + } else if (newtask->kerneltype == K_RBF) { + if (newtask->kernelparam[0] != oldtask->kernelparam[0]) + return true; + return false; + } else if (newtask->kerneltype == K_SIGMOID) { + for (i=0; i<2; i++) + if (newtask->kernelparam[i] != oldtask->kernelparam[i]) + return true; + return false; + } + return false; +} + + +void start_training(struct Queue *q) +{ + int f, folds; + double perf, current_max = 0; + struct Task *task = get_next_task(q); + struct Task *prevtask = NULL; + struct GenModel *model = gensvm_init_model(); + clock_t main_s, main_e, loop_s, loop_e; + + // in principle this can change between tasks, but this shouldn't be + // the case TODO + folds = task->folds; + + model->n = 0; + model->m = task->train_data->m; + model->K = task->train_data->K; + gensvm_allocate_model(model); + gensvm_seed_model_V(NULL, model, task->train_data); + + long *cv_idx = Calloc(long, task->train_data->n); + gensvm_make_cv_split(task->train_data->n, task->folds, cv_idx); + + struct GenData **train_folds = Malloc(struct GenData *, task->folds); + struct GenData **test_folds = Malloc(struct GenData *, task->folds); + for (f=0; f<folds; f++) { + train_folds[f] = gensvm_init_data(); + test_folds[f] = gensvm_init_data(); + gensvm_get_tt_split(task->train_data, train_folds[f], + test_folds[f], cv_idx, f); + } + + main_s = clock(); + while (task) { + print_progress_string(task, q->N); + make_model_from_task(task, model); + + if (kernel_changed(task, prevtask)) { + note("*"); + for (f=0; f<folds; f++) { + gensvm_kernel_preprocess(model, + train_folds[f]); + gensvm_kernel_postprocess(model, + train_folds[f], test_folds[f]); + } + note("*"); + } + + loop_s = clock(); + perf = gensvm_cross_validation(model, train_folds, test_folds, + folds, task->train_data->n); + loop_e = clock(); + current_max = maximum(current_max, perf); + + note("\t%3.3f%% (%3.3fs)\t(best = %3.3f%%)\n", perf, + elapsed_time(loop_s, loop_e), current_max); + + q->tasks[task->ID]->performance = perf; + prevtask = task; + task = get_next_task(q); + } + main_e = clock(); + + note("\nTotal elapsed training time: %8.8f seconds\n", + elapsed_time(main_s, main_e)); + + gensvm_free_model(model); + for (f=0; f<folds; f++) { + gensvm_free_data(train_folds[f]); + gensvm_free_data(test_folds[f]); + } + free(train_folds); + free(test_folds); +} + + +double gensvm_cross_validation(struct GenModel *model, + struct GenData **train_folds, struct GenData **test_folds, + int folds, long n_total) +{ + FILE *fid; + + int f; + long *predy; + double performance, total_perf = 0; + + for (f=0; f<folds; f++) { + // reallocate model in case dimensions differ with data + gensvm_reallocate_model(model, train_folds[f]->n, + train_folds[f]->r); + + // initialize object weights + gensvm_initialize_weights(train_folds[f], model); + + // train the model (surpressing output) + fid = GENSVM_OUTPUT_FILE; + GENSVM_OUTPUT_FILE = NULL; + gensvm_optimize(model, train_folds[f]); + GENSVM_OUTPUT_FILE = fid; + + // calculate prediction performance on test set + predy = Calloc(long, test_folds[f]->n); + gensvm_predict_labels(test_folds[f], model, predy); + performance = gensvm_prediction_perf(test_folds[f], predy); + total_perf += performance * test_folds[f]->n; + + free(predy); + } + + total_perf /= ((double) n_total); + + return total_perf; +} + + /** * @brief Run the grid search for a train/test dataset * @@ -563,6 +710,11 @@ void start_training_cv(struct Queue *q) */ void start_training_tt(struct Queue *q) { + return; +} +/* +void start_training_tt(struct Queue *q) +{ FILE *fid; long c = 0; @@ -628,6 +780,7 @@ void start_training_tt(struct Queue *q) free(task); gensvm_free_model(seed_model); } +*/ /** * @brief Free the Queue struct |
