aboutsummaryrefslogtreecommitdiff
path: root/src/gensvm_cross_validation.c
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2016-11-03 15:55:03 +0100
committerGertjan van den Burg <burg@ese.eur.nl>2016-11-03 15:55:03 +0100
commitc3edde20d385614f0016b74e03575344b7c5081a (patch)
tree314d386874ea60dccf8e111fa856bac06c9f656a /src/gensvm_cross_validation.c
parentupdate copyright information (diff)
downloadgensvm-c3edde20d385614f0016b74e03575344b7c5081a.tar.gz
gensvm-c3edde20d385614f0016b74e03575344b7c5081a.zip
prepare for gridsearch unit testing
Diffstat (limited to 'src/gensvm_cross_validation.c')
-rw-r--r--src/gensvm_cross_validation.c91
1 files changed, 91 insertions, 0 deletions
diff --git a/src/gensvm_cross_validation.c b/src/gensvm_cross_validation.c
new file mode 100644
index 0000000..2fe6198
--- /dev/null
+++ b/src/gensvm_cross_validation.c
@@ -0,0 +1,91 @@
+/**
+ * @file gensvm_cross_validation.c
+ * @author G.J.J. van den Burg
+ * @date 2016-10-24
+ * @brief Function for running cross validation on GenModel
+ *
+ * @copyright
+ Copyright 2016, G.J.J. van den Burg.
+
+ This file is part of GenSVM.
+
+ GenSVM is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ GenSVM is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
+
+ */
+
+#include "gensvm_cross_validation.h"
+
+extern FILE *GENSVM_OUTPUT_FILE;
+
+/**
+ * @brief Run cross validation with a given set of train/test folds
+ *
+ * @details
+ * This cross validation function uses predefined train/test splits. Also, the
+ * the optimal parameters GenModel::V of a previous fold as initial conditions
+ * for GenModel::V of the next fold.
+ *
+ * @note
+ * This function always sets the output stream defined in GENSVM_OUTPUT_FILE
+ * to NULL, to ensure gensvm_optimize() doesn't print too much.
+ *
+ * @param[in] model GenModel with the configuration to train
+ * @param[in] train_folds array of training datasets
+ * @param[in] test_folds array of test datasets
+ * @param[in] folds number of folds
+ * @param[in] n_total number of objects in the union of the train
+ * datasets
+ * @return performance (hitrate) of the configuration on
+ * cross validation
+ */
+double gensvm_cross_validation(struct GenModel *model,
+ struct GenData **train_folds, struct GenData **test_folds,
+ int folds, long n_total)
+{
+ int f;
+ long *predy = NULL;
+ double performance, total_perf = 0;
+
+ // make sure that gensvm_optimize() is silent.
+ FILE *fid = GENSVM_OUTPUT_FILE;
+ GENSVM_OUTPUT_FILE = NULL;
+
+ // run cross-validation
+ for (f=0; f<folds; f++) {
+ // reallocate model in case dimensions differ with data
+ gensvm_reallocate_model(model, train_folds[f]->n,
+ train_folds[f]->r);
+
+ // initialize object weights
+ gensvm_initialize_weights(train_folds[f], model);
+
+ // train the model (surpressing output)
+ gensvm_optimize(model, train_folds[f]);
+
+ // calculate prediction performance on test set
+ predy = Calloc(long, test_folds[f]->n);
+ gensvm_predict_labels(test_folds[f], model, predy);
+ performance = gensvm_prediction_perf(test_folds[f], predy);
+ total_perf += performance * test_folds[f]->n;
+
+ free(predy);
+ }
+
+ total_perf /= ((double) n_total);
+
+ // reset the output stream
+ GENSVM_OUTPUT_FILE = fid;
+
+ return total_perf;
+}