From 6d064658f8ae7ca0f42fef6dcc7f896144e9637b Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Fri, 18 Oct 2013 15:48:59 +0200 Subject: restart using git --- src/trainMSVMMaj.c | 62 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 27 deletions(-) (limited to 'src/trainMSVMMaj.c') diff --git a/src/trainMSVMMaj.c b/src/trainMSVMMaj.c index 5a403be..ebcf36c 100644 --- a/src/trainMSVMMaj.c +++ b/src/trainMSVMMaj.c @@ -1,10 +1,18 @@ +#include +#include + #include "libMSVMMaj.h" +#include "msvmmaj_train.h" +#include "util.h" +#include "MSVMMaj.h" #define MINARGS 2 +extern FILE *MSVMMAJ_OUTPUT_FILE; + void print_null(const char *s) {} void exit_with_help(); -void parse_command_line(int argc, char **argv, struct Model *model, +void parse_command_line(int argc, char **argv, struct MajModel *model, char *input_filename, char *output_filename, char *model_filename); void exit_with_help() @@ -12,7 +20,6 @@ void exit_with_help() printf("This is MSVMMaj, version %1.1f\n\n", VERSION); printf("Usage: trainMSVMMaj [options] training_data_file\n"); printf("Options:\n"); - printf("-c folds : perform cross validation with given number of folds\n"); printf("-e epsilon : set the value of the stopping criterion\n"); printf("-h | -help : print this help.\n"); printf("-k kappa : set the value of kappa used in the Huber hinge\n"); @@ -35,15 +42,16 @@ int main(int argc, char **argv) char model_filename[MAX_LINE_LENGTH]; char output_filename[MAX_LINE_LENGTH]; - struct Model *model = Malloc(struct Model, 1); - struct Data *data = Malloc(struct Data, 1); + struct MajModel *model = Malloc(struct MajModel, 1); + struct MajData *data = Malloc(struct MajData, 1); - if (argc < MINARGS || check_argv(argc, argv, "-help") || check_argv_eq(argc, argv, "-h") ) + if (argc < MINARGS || msvmmaj_check_argv(argc, argv, "-help") + || msvmmaj_check_argv_eq(argc, argv, "-h") ) exit_with_help(); parse_command_line(argc, argv, model, input_filename, output_filename, model_filename); // read data file - read_data(data, input_filename); + msvmmaj_read_data(data, input_filename); // copy dataset parameters to model model->n = data->n; @@ -52,39 +60,40 @@ int main(int argc, char **argv) model->data_file = input_filename; // allocate model and initialize weights - allocate_model(model); - initialize_weights(data, model); - - if (check_argv_eq(argc, argv, "-m")) { - struct Model *seed_model = Malloc(struct Model, 1); - read_model(seed_model, model_filename); - seed_model_V(seed_model, model); - free_model(seed_model); + msvmmaj_allocate_model(model); + msvmmaj_initialize_weights(data, model); + + srand(time(NULL)); + + if (msvmmaj_check_argv_eq(argc, argv, "-m")) { + struct MajModel *seed_model = Malloc(struct MajModel, 1); + msvmmaj_read_model(seed_model, model_filename); + msvmmaj_seed_model_V(seed_model, model); + msvmmaj_free_model(seed_model); } else { - seed_model_V(NULL, model); + msvmmaj_seed_model_V(NULL, model); } // start training - main_loop(model, data); + msvmmaj_optimize(model, data); // write_model to file - if (check_argv_eq(argc, argv, "-o")) { - write_model(model, output_filename); - info("Output written to %s\n", output_filename); + if (msvmmaj_check_argv_eq(argc, argv, "-o")) { + msvmmaj_write_model(model, output_filename); + note("Output written to %s\n", output_filename); } // free model and data - free_model(model); - free_data(data); + msvmmaj_free_model(model); + msvmmaj_free_data(data); return 0; } -void parse_command_line(int argc, char **argv, struct Model *model, +void parse_command_line(int argc, char **argv, struct MajModel *model, char *input_filename, char *output_filename, char *model_filename) { int i; - void (*print_func)(const char*) = NULL; // default values model->p = 1.0; @@ -92,6 +101,8 @@ void parse_command_line(int argc, char **argv, struct Model *model, model->epsilon = 1e-6; model->kappa = 0.0; model->weight_idx = 1; + + MSVMMAJ_OUTPUT_FILE = stdout; // parse options for (i=1; iweight_idx = atoi(argv[i]); break; case 'q': - print_func = &print_null; + MSVMMAJ_OUTPUT_FILE = NULL; i--; break; default: @@ -131,9 +142,6 @@ void parse_command_line(int argc, char **argv, struct Model *model, } } - // set print function - set_print_string_function(print_func); - // read input filename if (i >= argc) exit_with_help(); -- cgit v1.2.3