From c3edde20d385614f0016b74e03575344b7c5081a Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Thu, 3 Nov 2016 15:55:03 +0100 Subject: prepare for gridsearch unit testing --- src/gensvm_cross_validation.c | 91 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 src/gensvm_cross_validation.c (limited to 'src/gensvm_cross_validation.c') diff --git a/src/gensvm_cross_validation.c b/src/gensvm_cross_validation.c new file mode 100644 index 0000000..2fe6198 --- /dev/null +++ b/src/gensvm_cross_validation.c @@ -0,0 +1,91 @@ +/** + * @file gensvm_cross_validation.c + * @author G.J.J. van den Burg + * @date 2016-10-24 + * @brief Function for running cross validation on GenModel + * + * @copyright + Copyright 2016, G.J.J. van den Burg. + + This file is part of GenSVM. + + GenSVM is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + GenSVM is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with GenSVM. If not, see . + + */ + +#include "gensvm_cross_validation.h" + +extern FILE *GENSVM_OUTPUT_FILE; + +/** + * @brief Run cross validation with a given set of train/test folds + * + * @details + * This cross validation function uses predefined train/test splits. Also, the + * the optimal parameters GenModel::V of a previous fold as initial conditions + * for GenModel::V of the next fold. + * + * @note + * This function always sets the output stream defined in GENSVM_OUTPUT_FILE + * to NULL, to ensure gensvm_optimize() doesn't print too much. + * + * @param[in] model GenModel with the configuration to train + * @param[in] train_folds array of training datasets + * @param[in] test_folds array of test datasets + * @param[in] folds number of folds + * @param[in] n_total number of objects in the union of the train + * datasets + * @return performance (hitrate) of the configuration on + * cross validation + */ +double gensvm_cross_validation(struct GenModel *model, + struct GenData **train_folds, struct GenData **test_folds, + int folds, long n_total) +{ + int f; + long *predy = NULL; + double performance, total_perf = 0; + + // make sure that gensvm_optimize() is silent. + FILE *fid = GENSVM_OUTPUT_FILE; + GENSVM_OUTPUT_FILE = NULL; + + // run cross-validation + for (f=0; fn, + train_folds[f]->r); + + // initialize object weights + gensvm_initialize_weights(train_folds[f], model); + + // train the model (surpressing output) + gensvm_optimize(model, train_folds[f]); + + // calculate prediction performance on test set + predy = Calloc(long, test_folds[f]->n); + gensvm_predict_labels(test_folds[f], model, predy); + performance = gensvm_prediction_perf(test_folds[f], predy); + total_perf += performance * test_folds[f]->n; + + free(predy); + } + + total_perf /= ((double) n_total); + + // reset the output stream + GENSVM_OUTPUT_FILE = fid; + + return total_perf; +} -- cgit v1.2.3