diff options
| field | value | date |
|---|---|---|
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2016-05-16 21:41:27 +0200 |
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2016-05-16 21:41:27 +0200 |
| commit | 2584591266e434140a94d818402745f34fb3acf3 (patch) | |
| tree | 2faeb4f27b6b4101149e608657d2d018249e73c1 | |
| parent | major refactor of the code (diff) | |
| download | gensvm-2584591266e434140a94d818402745f34fb3acf3.tar.gz gensvm-2584591266e434140a94d818402745f34fb3acf3.zip | |
create a single training function for easy external access
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | Makefile | 6 |
| -rw-r--r-- | include/gensvm_optimize.h | 4 |
| -rw-r--r-- | include/gensvm_train.h | 21 |
| -rw-r--r-- | src/GenSVMtraintest.c | 45 |
| -rw-r--r-- | src/gensvm_base.c | 3 |
| -rw-r--r-- | src/gensvm_optimize.c | 2 |
| -rw-r--r-- | src/gensvm_train.c | 51 |

7 files changed, 92 insertions, 40 deletions
@@ -47,7 +47,8 @@ lib/libgensvm.a: \ src/gensvm_strutil.o \ src/gensvm_sv.o \ src/gensvm_task.o \ - src/gensvm_timer.o + src/gensvm_timer.o \ + src/gensvm_train.o @ar rcs lib/libgensvm.a \ src/gensvm_base.o \ src/gensvm_cmdarg.o \ @@ -67,7 +68,8 @@ lib/libgensvm.a: \ src/gensvm_strutil.o \ src/gensvm_sv.o \ src/gensvm_task.o \ - src/gensvm_timer.o + src/gensvm_timer.o \ + src/gensvm_train.o @echo libgensvm.a... gensvm: src/GenSVMtraintest.c lib/libgensvm.a diff --git a/include/gensvm_optimize.h b/include/gensvm_optimize.h index 7a23bdb..1041b72 100644 --- a/include/gensvm_optimize.h +++ b/include/gensvm_optimize.h @@ -10,8 +10,8 @@ * */ -#ifndef GENSVM_TRAIN_H -#define GENSVM_TRAIN_H +#ifndef GENSVM_OPTIMIZE_H +#define GENSVM_OPTIMIZE_H #include "gensvm_sv.h" #include "gensvm_print.h" diff --git a/include/gensvm_train.h b/include/gensvm_train.h new file mode 100644 index 0000000..c410cef --- /dev/null +++ b/include/gensvm_train.h @@ -0,0 +1,21 @@ +/** + * @file gensvm_train.h + * @author Gertjan van den Burg + * @date May, 2016 + * @brief Header file for gensvm_train.c + * + */ + +#ifndef GENSVM_TRAIN_H +#define GENSVM_TRAIN_H + +// includes +#include "gensvm_init.h" +#include "gensvm_kernel.h" +#include "gensvm_optimize.h" + +// function declarations +void gensvm_train(struct GenModel *model, struct GenData *data, + struct GenModel *seed_model); + +#endif diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 47cc900..73c0ce9 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -12,8 +12,7 @@ #include "gensvm_cmdarg.h" #include "gensvm_io.h" -#include "gensvm_init.h" -#include "gensvm_optimize.h" +#include "gensvm_train.h" #include "gensvm_pred.h" #define MINARGS 2 @@ -54,10 +53,8 @@ void exit_with_help() exit(0); } - int main(int argc, char **argv) { - bool with_test = false; long i, *predy = NULL; double performance; @@ -68,6 +65,7 @@ int main(int argc, char **argv) *prediction_outputfile = NULL; struct GenModel *model = 
gensvm_init_model(); + struct GenModel *seed_model = NULL; struct GenData *traindata = gensvm_init_data(); struct GenData *testdata = gensvm_init_data(); @@ -77,53 +75,29 @@ int main(int argc, char **argv) parse_command_line(argc, argv, model, &model_inputfile, &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); - if (testing_inputfile != NULL) - with_test = true; // read data from files gensvm_read_data(traindata, training_inputfile); - if (with_test) - gensvm_read_data(testdata, testing_inputfile); - - // copy dataset parameters to model - model->n = traindata->n; - model->m = traindata->m; - model->K = traindata->K; model->data_file = training_inputfile; - // allocate model - gensvm_allocate_model(model); - - // run pre/post processing in case of kernels - gensvm_kernel_preprocess(model, traindata); - if (with_test) - gensvm_kernel_postprocess(model, traindata, testdata); - - // reallocate model in case of kernel dimension reduction - gensvm_reallocate_model(model, traindata->n, traindata->r); - - // initialize weights - gensvm_initialize_weights(traindata, model); - // seed the random number generator srand(time(NULL)); // load a seed model from file if it is specified if (gensvm_check_argv_eq(argc, argv, "-s")) { - struct GenModel *seed_model = gensvm_init_model(); + seed_model = gensvm_init_model(); gensvm_read_model(seed_model, model_inputfile); - gensvm_init_V(seed_model, model, traindata); - gensvm_free_model(seed_model); - } else { - gensvm_init_V(NULL, model, traindata); } - // start training - gensvm_optimize(model, traindata); + // train the GenSVM model + gensvm_train(model, traindata, seed_model); // if we also have a test set, predict labels and write to predictions // to an output file if specified - if (with_test) { + if (testing_inputfile != NULL) { + gensvm_read_data(testdata, testing_inputfile); + gensvm_kernel_postprocess(model, traindata, testdata); + // predict labels predy = Calloc(long, testdata->n); 
gensvm_predict_labels(testdata, model, predy); @@ -154,6 +128,7 @@ int main(int argc, char **argv) // free everything gensvm_free_model(model); + gensvm_free_model(seed_model); gensvm_free_data(traindata); gensvm_free_data(testdata); free(training_inputfile); diff --git a/src/gensvm_base.c b/src/gensvm_base.c index eddef5c..568f19a 100644 --- a/src/gensvm_base.c +++ b/src/gensvm_base.c @@ -194,6 +194,9 @@ void gensvm_reallocate_model(struct GenModel *model, long n, long m) */ void gensvm_free_model(struct GenModel *model) { + if (model == NULL) + return; + free(model->W); free(model->t); free(model->V); diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c index 70b3620..464815d 100644 --- a/src/gensvm_optimize.c +++ b/src/gensvm_optimize.c @@ -1,5 +1,5 @@ /** - * @file gensvm_train.c + * @file gensvm_optimize.c * @author Gertjan van den Burg * @date August 9, 2013 * @brief Main functions for training the GenSVM solution. diff --git a/src/gensvm_train.c b/src/gensvm_train.c new file mode 100644 index 0000000..4c0f332 --- /dev/null +++ b/src/gensvm_train.c @@ -0,0 +1,51 @@ +/** + * @file gensvm_train.c + * @author Gertjan van den Burg + * @date May, 2016 + * @brief Main function for training a GenSVM model. + * + */ + +#include "gensvm_train.h" + +/** + * @brief Utility function for training a GenSVM model + * + * @details + * This function organizes model allocation, kernel preprocessing, instance + * weight initialization, and model training. It is the function that should + * be used for training a single GenSVM model. Note that optionally a seed + * model can be passed to the function to seed the V matrix with. When no such + * model is used this parameter should be set to NULL. 
+ * + * @param[in] model a GenModel instance + * @param[in] data a GenData instance with the training data + * @param[in] seed_model an optional GenModel to seed the V matrix + * + */ +void gensvm_train(struct GenModel *model, struct GenData *data, + struct GenModel *seed_model) +{ + // copy dataset parameters to model + model->n = data->n; + model->m = data->m; + model->K = data->K; + + // initialize the V matrix (potentially with a seed model) + gensvm_init_V(seed_model, model, data); + + // allocate model + gensvm_allocate_model(model); + + // preprocess kernel + gensvm_kernel_preprocess(model, data); + + // reallocate model for kernels + gensvm_reallocate_model(model, data->n, data->r); + + // initialize weights + gensvm_initialize_weights(data, model); + + // start training + gensvm_optimize(model, data); +} |
