diff options
Diffstat (limited to 'src/GenSVMtraintest.c')
| -rw-r--r-- | src/GenSVMtraintest.c | 45 |
1 files changed, 10 insertions, 35 deletions
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 47cc900..73c0ce9 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -12,8 +12,7 @@ #include "gensvm_cmdarg.h" #include "gensvm_io.h" -#include "gensvm_init.h" -#include "gensvm_optimize.h" +#include "gensvm_train.h" #include "gensvm_pred.h" #define MINARGS 2 @@ -54,10 +53,8 @@ void exit_with_help() exit(0); } - int main(int argc, char **argv) { - bool with_test = false; long i, *predy = NULL; double performance; @@ -68,6 +65,7 @@ int main(int argc, char **argv) *prediction_outputfile = NULL; struct GenModel *model = gensvm_init_model(); + struct GenModel *seed_model = NULL; struct GenData *traindata = gensvm_init_data(); struct GenData *testdata = gensvm_init_data(); @@ -77,53 +75,29 @@ int main(int argc, char **argv) parse_command_line(argc, argv, model, &model_inputfile, &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); - if (testing_inputfile != NULL) - with_test = true; // read data from files gensvm_read_data(traindata, training_inputfile); - if (with_test) - gensvm_read_data(testdata, testing_inputfile); - - // copy dataset parameters to model - model->n = traindata->n; - model->m = traindata->m; - model->K = traindata->K; model->data_file = training_inputfile; - // allocate model - gensvm_allocate_model(model); - - // run pre/post processing in case of kernels - gensvm_kernel_preprocess(model, traindata); - if (with_test) - gensvm_kernel_postprocess(model, traindata, testdata); - - // reallocate model in case of kernel dimension reduction - gensvm_reallocate_model(model, traindata->n, traindata->r); - - // initialize weights - gensvm_initialize_weights(traindata, model); - // seed the random number generator srand(time(NULL)); // load a seed model from file if it is specified if (gensvm_check_argv_eq(argc, argv, "-s")) { - struct GenModel *seed_model = gensvm_init_model(); + seed_model = gensvm_init_model(); gensvm_read_model(seed_model, model_inputfile); - gensvm_init_V(seed_model, model, traindata); - gensvm_free_model(seed_model); - } else { - gensvm_init_V(NULL, model, traindata); } - // start training - gensvm_optimize(model, traindata); + // train the GenSVM model + gensvm_train(model, traindata, seed_model); // if we also have a test set, predict labels and write to predictions // to an output file if specified - if (with_test) { + if (testing_inputfile != NULL) { + gensvm_read_data(testdata, testing_inputfile); + gensvm_kernel_postprocess(model, traindata, testdata); + // predict labels predy = Calloc(long, testdata->n); gensvm_predict_labels(testdata, model, predy); @@ -154,6 +128,7 @@ int main(int argc, char **argv) // free everything gensvm_free_model(model); + gensvm_free_model(seed_model); gensvm_free_data(traindata); gensvm_free_data(testdata); free(training_inputfile); |
