/**
* @file gensvm_cross_validation.c
* @author G.J.J. van den Burg
* @date 2016-10-24
* @brief Function for running cross validation on GenModel
*
* @copyright
Copyright 2016, G.J.J. van den Burg.
This file is part of GenSVM.
GenSVM is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
GenSVM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
*/
#include "gensvm_cross_validation.h"
extern FILE *GENSVM_OUTPUT_FILE;
/**
* @brief Run cross validation with a given set of train/test folds
*
* @details
* This cross validation function uses predefined train/test splits. Also, it
* uses the optimal parameters GenModel::V of a previous fold as initial
* conditions for GenModel::V of the next fold.
*
* @note
* This function always sets the output stream defined in GENSVM_OUTPUT_FILE
* to NULL, to ensure gensvm_optimize() doesn't print too much.
*
* @param[in] model GenModel with the configuration to train
* @param[in] train_folds array of training datasets
* @param[in] test_folds array of test datasets
* @param[in] folds number of folds
* @param[in] n_total number of objects in the union of the train
* datasets
* @return performance (hitrate) of the configuration on
* cross validation
*/
double gensvm_cross_validation(struct GenModel *model,
		struct GenData **train_folds, struct GenData **test_folds,
		long folds, long n_total)
{
	long f;
	long *predy = NULL;
	double performance, total_perf = 0;

	// make sure that gensvm_optimize() is silent.
	FILE *fid = GENSVM_OUTPUT_FILE;
	GENSVM_OUTPUT_FILE = NULL;

	// run cross-validation over the predefined folds
	for (f=0; f<folds; f++) {
		// reallocate the model for this fold, since the fold
		// dimensions (n, r) can differ between folds
		gensvm_reallocate_model(model, train_folds[f]->n,
				train_folds[f]->r);

		// initialize object weights
		gensvm_initialize_weights(train_folds[f], model);

		// train the model (suppressing output); model->V from the
		// previous fold serves as a warm start
		gensvm_optimize(model, train_folds[f]);

		// calculate prediction performance on the test set
		predy = Calloc(long, test_folds[f]->n);
		gensvm_predict_labels(test_folds[f], model, predy);
		performance = gensvm_prediction_perf(test_folds[f], predy);
		// accumulate fold hitrate weighted by fold size
		total_perf += performance * test_folds[f]->n;
		free(predy);
	}

	// size-weighted average over all test folds
	total_perf /= ((double) n_total);

	// reset the output stream
	GENSVM_OUTPUT_FILE = fid;

	return total_perf;
}