diff options
Diffstat (limited to 'src/gensvm_cross_validation.c')
| -rw-r--r-- | src/gensvm_cross_validation.c | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/src/gensvm_cross_validation.c b/src/gensvm_cross_validation.c new file mode 100644 index 0000000..2fe6198 --- /dev/null +++ b/src/gensvm_cross_validation.c @@ -0,0 +1,91 @@ +/** + * @file gensvm_cross_validation.c + * @author G.J.J. van den Burg + * @date 2016-10-24 + * @brief Function for running cross validation on GenModel + * + * @copyright + Copyright 2016, G.J.J. van den Burg. + + This file is part of GenSVM. + + GenSVM is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + GenSVM is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with GenSVM. If not, see <http://www.gnu.org/licenses/>. + + */ + +#include "gensvm_cross_validation.h" + +extern FILE *GENSVM_OUTPUT_FILE; + +/** + * @brief Run cross validation with a given set of train/test folds + * + * @details + * This cross validation function uses predefined train/test splits. Also, the + * the optimal parameters GenModel::V of a previous fold as initial conditions + * for GenModel::V of the next fold. + * + * @note + * This function always sets the output stream defined in GENSVM_OUTPUT_FILE + * to NULL, to ensure gensvm_optimize() doesn't print too much. + * + * @param[in] model GenModel with the configuration to train + * @param[in] train_folds array of training datasets + * @param[in] test_folds array of test datasets + * @param[in] folds number of folds + * @param[in] n_total number of objects in the union of the train + * datasets + * @return performance (hitrate) of the configuration on + * cross validation + */ +double gensvm_cross_validation(struct GenModel *model, + struct GenData **train_folds, struct GenData **test_folds, + int folds, long n_total) +{ + int f; + long *predy = NULL; + double performance, total_perf = 0; + + // make sure that gensvm_optimize() is silent. + FILE *fid = GENSVM_OUTPUT_FILE; + GENSVM_OUTPUT_FILE = NULL; + + // run cross-validation + for (f=0; f<folds; f++) { + // reallocate model in case dimensions differ with data + gensvm_reallocate_model(model, train_folds[f]->n, + train_folds[f]->r); + + // initialize object weights + gensvm_initialize_weights(train_folds[f], model); + + // train the model (surpressing output) + gensvm_optimize(model, train_folds[f]); + + // calculate prediction performance on test set + predy = Calloc(long, test_folds[f]->n); + gensvm_predict_labels(test_folds[f], model, predy); + performance = gensvm_prediction_perf(test_folds[f], predy); + total_perf += performance * test_folds[f]->n; + + free(predy); + } + + total_perf /= ((double) n_total); + + // reset the output stream + GENSVM_OUTPUT_FILE = fid; + + return total_perf; +} |
