aboutsummaryrefslogtreecommitdiff
path: root/src/gensvm_train_dataset.c
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2015-01-30 16:22:52 +0100
committerGertjan van den Burg <burg@ese.eur.nl>2015-01-30 16:22:52 +0100
commitdf9c3ca0b62f1a20071bee3a55d24d673c5d11e0 (patch)
treed3a2d6be5dfe6e2a4e248ad04dfdbb40852c8f7a /src/gensvm_train_dataset.c
parentupdate documentation gensvm structs (diff)
downloadgensvm-df9c3ca0b62f1a20071bee3a55d24d673c5d11e0.tar.gz
gensvm-df9c3ca0b62f1a20071bee3a55d24d673c5d11e0.zip
first working version of new kernel GenSVM
Diffstat (limited to 'src/gensvm_train_dataset.c')
-rw-r--r--src/gensvm_train_dataset.c155
1 files changed, 154 insertions, 1 deletions
diff --git a/src/gensvm_train_dataset.c b/src/gensvm_train_dataset.c
index 3034bb4..eee4bf9 100644
--- a/src/gensvm_train_dataset.c
+++ b/src/gensvm_train_dataset.c
@@ -435,6 +435,12 @@ void consistency_repeats(struct Queue *q, long repeats, TrainType traintype)
* cross validation
*/
double cross_validation(struct GenModel *model, struct GenData *data,
+ long folds)
+{
+ return 0.0;
+}
+/*
+double cross_validation(struct GenModel *model, struct GenData *data,
long folds)
{
FILE *fid;
@@ -487,7 +493,7 @@ double cross_validation(struct GenModel *model, struct GenData *data,
return total_perf;
}
-
+*/
/**
* @brief Run the grid search for a cross validation dataset
*
@@ -542,6 +548,147 @@ void start_training_cv(struct Queue *q)
gensvm_free_model(model);
}
+
+bool kernel_changed(struct Task *newtask, struct Task *oldtask)
+{
+ if (oldtask == NULL)
+ return true;
+ int i;
+ if (newtask->kerneltype != oldtask->kerneltype) {
+ return true;
+ } else if (newtask->kerneltype == K_POLY) {
+ for (i=0; i<3; i++)
+ if (newtask->kernelparam[i] != oldtask->kernelparam[i])
+ return true;
+ return false;
+ } else if (newtask->kerneltype == K_RBF) {
+ if (newtask->kernelparam[0] != oldtask->kernelparam[0])
+ return true;
+ return false;
+ } else if (newtask->kerneltype == K_SIGMOID) {
+ for (i=0; i<2; i++)
+ if (newtask->kernelparam[i] != oldtask->kernelparam[i])
+ return true;
+ return false;
+ }
+ return false;
+}
+
+
+void start_training(struct Queue *q)
+{
+ int f, folds;
+ double perf, current_max = 0;
+ struct Task *task = get_next_task(q);
+ struct Task *prevtask = NULL;
+ struct GenModel *model = gensvm_init_model();
+ clock_t main_s, main_e, loop_s, loop_e;
+
+ // in principle this can change between tasks, but this shouldn't be
+ // the case TODO
+ folds = task->folds;
+
+ model->n = 0;
+ model->m = task->train_data->m;
+ model->K = task->train_data->K;
+ gensvm_allocate_model(model);
+ gensvm_seed_model_V(NULL, model, task->train_data);
+
+ long *cv_idx = Calloc(long, task->train_data->n);
+ gensvm_make_cv_split(task->train_data->n, task->folds, cv_idx);
+
+ struct GenData **train_folds = Malloc(struct GenData *, task->folds);
+ struct GenData **test_folds = Malloc(struct GenData *, task->folds);
+ for (f=0; f<folds; f++) {
+ train_folds[f] = gensvm_init_data();
+ test_folds[f] = gensvm_init_data();
+ gensvm_get_tt_split(task->train_data, train_folds[f],
+ test_folds[f], cv_idx, f);
+ }
+
+ main_s = clock();
+ while (task) {
+ print_progress_string(task, q->N);
+ make_model_from_task(task, model);
+
+ if (kernel_changed(task, prevtask)) {
+ note("*");
+ for (f=0; f<folds; f++) {
+ gensvm_kernel_preprocess(model,
+ train_folds[f]);
+ gensvm_kernel_postprocess(model,
+ train_folds[f], test_folds[f]);
+ }
+ note("*");
+ }
+
+ loop_s = clock();
+ perf = gensvm_cross_validation(model, train_folds, test_folds,
+ folds, task->train_data->n);
+ loop_e = clock();
+ current_max = maximum(current_max, perf);
+
+ note("\t%3.3f%% (%3.3fs)\t(best = %3.3f%%)\n", perf,
+ elapsed_time(loop_s, loop_e), current_max);
+
+ q->tasks[task->ID]->performance = perf;
+ prevtask = task;
+ task = get_next_task(q);
+ }
+ main_e = clock();
+
+ note("\nTotal elapsed training time: %8.8f seconds\n",
+ elapsed_time(main_s, main_e));
+
+ gensvm_free_model(model);
+ for (f=0; f<folds; f++) {
+ gensvm_free_data(train_folds[f]);
+ gensvm_free_data(test_folds[f]);
+ }
+ free(train_folds);
+ free(test_folds);
+}
+
+
+double gensvm_cross_validation(struct GenModel *model,
+ struct GenData **train_folds, struct GenData **test_folds,
+ int folds, long n_total)
+{
+ FILE *fid;
+
+ int f;
+ long *predy;
+ double performance, total_perf = 0;
+
+ for (f=0; f<folds; f++) {
+ // reallocate model in case dimensions differ with data
+ gensvm_reallocate_model(model, train_folds[f]->n,
+ train_folds[f]->r);
+
+ // initialize object weights
+ gensvm_initialize_weights(train_folds[f], model);
+
+ // train the model (surpressing output)
+ fid = GENSVM_OUTPUT_FILE;
+ GENSVM_OUTPUT_FILE = NULL;
+ gensvm_optimize(model, train_folds[f]);
+ GENSVM_OUTPUT_FILE = fid;
+
+ // calculate prediction performance on test set
+ predy = Calloc(long, test_folds[f]->n);
+ gensvm_predict_labels(test_folds[f], model, predy);
+ performance = gensvm_prediction_perf(test_folds[f], predy);
+ total_perf += performance * test_folds[f]->n;
+
+ free(predy);
+ }
+
+ total_perf /= ((double) n_total);
+
+ return total_perf;
+}
+
+
/**
* @brief Run the grid search for a train/test dataset
*
@@ -563,6 +710,11 @@ void start_training_cv(struct Queue *q)
*/
void start_training_tt(struct Queue *q)
{
+ return;
+}
+/*
+void start_training_tt(struct Queue *q)
+{
FILE *fid;
long c = 0;
@@ -628,6 +780,7 @@ void start_training_tt(struct Queue *q)
free(task);
gensvm_free_model(seed_model);
}
+*/
/**
* @brief Free the Queue struct