aboutsummaryrefslogtreecommitdiff
path: root/src/trainMSVMMaj.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/trainMSVMMaj.c')
-rw-r--r--src/trainMSVMMaj.c62
1 files changed, 35 insertions, 27 deletions
diff --git a/src/trainMSVMMaj.c b/src/trainMSVMMaj.c
index 5a403be..ebcf36c 100644
--- a/src/trainMSVMMaj.c
+++ b/src/trainMSVMMaj.c
@@ -1,10 +1,18 @@
+#include <time.h>
+#include <math.h>
+
#include "libMSVMMaj.h"
+#include "msvmmaj_train.h"
+#include "util.h"
+#include "MSVMMaj.h"
#define MINARGS 2
+extern FILE *MSVMMAJ_OUTPUT_FILE;
+
void print_null(const char *s) {}
void exit_with_help();
-void parse_command_line(int argc, char **argv, struct Model *model,
+void parse_command_line(int argc, char **argv, struct MajModel *model,
char *input_filename, char *output_filename, char *model_filename);
void exit_with_help()
@@ -12,7 +20,6 @@ void exit_with_help()
printf("This is MSVMMaj, version %1.1f\n\n", VERSION);
printf("Usage: trainMSVMMaj [options] training_data_file\n");
printf("Options:\n");
- printf("-c folds : perform cross validation with given number of folds\n");
printf("-e epsilon : set the value of the stopping criterion\n");
printf("-h | -help : print this help.\n");
printf("-k kappa : set the value of kappa used in the Huber hinge\n");
@@ -35,15 +42,16 @@ int main(int argc, char **argv)
char model_filename[MAX_LINE_LENGTH];
char output_filename[MAX_LINE_LENGTH];
- struct Model *model = Malloc(struct Model, 1);
- struct Data *data = Malloc(struct Data, 1);
+ struct MajModel *model = Malloc(struct MajModel, 1);
+ struct MajData *data = Malloc(struct MajData, 1);
- if (argc < MINARGS || check_argv(argc, argv, "-help") || check_argv_eq(argc, argv, "-h") )
+ if (argc < MINARGS || msvmmaj_check_argv(argc, argv, "-help")
+ || msvmmaj_check_argv_eq(argc, argv, "-h") )
exit_with_help();
parse_command_line(argc, argv, model, input_filename, output_filename, model_filename);
// read data file
- read_data(data, input_filename);
+ msvmmaj_read_data(data, input_filename);
// copy dataset parameters to model
model->n = data->n;
@@ -52,39 +60,40 @@ int main(int argc, char **argv)
model->data_file = input_filename;
// allocate model and initialize weights
- allocate_model(model);
- initialize_weights(data, model);
-
- if (check_argv_eq(argc, argv, "-m")) {
- struct Model *seed_model = Malloc(struct Model, 1);
- read_model(seed_model, model_filename);
- seed_model_V(seed_model, model);
- free_model(seed_model);
+ msvmmaj_allocate_model(model);
+ msvmmaj_initialize_weights(data, model);
+
+ srand(time(NULL));
+
+ if (msvmmaj_check_argv_eq(argc, argv, "-m")) {
+ struct MajModel *seed_model = Malloc(struct MajModel, 1);
+ msvmmaj_read_model(seed_model, model_filename);
+ msvmmaj_seed_model_V(seed_model, model);
+ msvmmaj_free_model(seed_model);
} else {
- seed_model_V(NULL, model);
+ msvmmaj_seed_model_V(NULL, model);
}
// start training
- main_loop(model, data);
+ msvmmaj_optimize(model, data);
// write_model to file
- if (check_argv_eq(argc, argv, "-o")) {
- write_model(model, output_filename);
- info("Output written to %s\n", output_filename);
+ if (msvmmaj_check_argv_eq(argc, argv, "-o")) {
+ msvmmaj_write_model(model, output_filename);
+ note("Output written to %s\n", output_filename);
}
// free model and data
- free_model(model);
- free_data(data);
+ msvmmaj_free_model(model);
+ msvmmaj_free_data(data);
return 0;
}
-void parse_command_line(int argc, char **argv, struct Model *model,
+void parse_command_line(int argc, char **argv, struct MajModel *model,
char *input_filename, char *output_filename, char *model_filename)
{
int i;
- void (*print_func)(const char*) = NULL;
// default values
model->p = 1.0;
@@ -92,6 +101,8 @@ void parse_command_line(int argc, char **argv, struct Model *model,
model->epsilon = 1e-6;
model->kappa = 0.0;
model->weight_idx = 1;
+
+ MSVMMAJ_OUTPUT_FILE = stdout;
// parse options
for (i=1; i<argc; i++) {
@@ -122,7 +133,7 @@ void parse_command_line(int argc, char **argv, struct Model *model,
model->weight_idx = atoi(argv[i]);
break;
case 'q':
- print_func = &print_null;
+ MSVMMAJ_OUTPUT_FILE = NULL;
i--;
break;
default:
@@ -131,9 +142,6 @@ void parse_command_line(int argc, char **argv, struct Model *model,
}
}
- // set print function
- set_print_string_function(print_func);
-
// read input filename
if (i >= argc)
exit_with_help();