aboutsummaryrefslogtreecommitdiff
path: root/src/trainMSVMMaj.c
blob: 5a403be969b70207b2c94c36a8c714ca1ef63645 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include "libMSVMMaj.h"

// Smallest argc that main() accepts; with fewer arguments the usage text is
// printed instead of starting a training run.
#define MINARGS 2

/*
	No-op print callback: takes a message and discards it. Installed as the
	print function by parse_command_line() when -q is given.
*/
void print_null(const char *s)
{
	(void) s;
}
// Prints the usage text and terminates the program.
void exit_with_help();
// Sets the default model parameters, applies the options found in argv, and
// copies the file names given on the command line into the supplied buffers.
void parse_command_line(int argc, char **argv, struct Model *model, 
		char *input_filename, char *output_filename, char *model_filename);

/*
	Print the usage message of trainMSVMMaj to stdout and terminate the
	program with exit status 0. This function never returns to its caller.

	NOTE(review): the text advertises "-c folds", but parse_command_line()
	in this file has no case for 'c' and reports it as an unknown option.
	Confirm whether cross validation is meant to be supported here or
	whether the line should be dropped from the help.
*/
void exit_with_help(void)
{
	printf("This is MSVMMaj, version %1.1f\n\n", VERSION);
	printf("Usage: trainMSVMMaj [options] training_data_file\n");
	printf("Options:\n");
	printf("-c folds : perform cross validation with given number of folds\n");
	printf("-e epsilon : set the value of the stopping criterion\n");
	printf("-h | -help : print this help.\n");
	printf("-k kappa : set the value of kappa used in the Huber hinge\n");
	printf("-l lambda : set the value of lambda (lambda > 0)\n");
	printf("-m model_file : use previous model as seed for W and t\n");
	printf("-o output_file : write output to file\n");
	printf("-p p-value : set the value of p in the lp norm (1.0 <= p <= 2.0)\n");
	printf("-q : quiet mode (no output)\n");
	// Spelling fixed: the original text read "weigth".
	printf("-r rho : choose the weight specification (1 = unit, 2 = group)\n");

	exit(0);
}

/*
	Entry point of the trainMSVMMaj command line tool.

	Steps, in order:
	  1. show the usage text when too few arguments are given or when
	     -help / -h is present;
	  2. parse the command line into the model and the file name buffers;
	  3. read the training data and copy n, m and K from it to the model;
	  4. allocate the model, initialize the weights and seed V, from the
	     model file when -m was given and from NULL otherwise;
	  5. run the training loop;
	  6. write the model to the output file when -o was given;
	  7. release the model and the data.

	Returns 0 after a completed run.
*/
int main(int argc, char **argv)
{
	char train_path[MAX_LINE_LENGTH];
	char seed_path[MAX_LINE_LENGTH];
	char result_path[MAX_LINE_LENGTH];
	struct Model *seed = NULL;
	int show_help;

	struct Model *model = Malloc(struct Model, 1);
	struct Data *data = Malloc(struct Data, 1);

	show_help = argc < MINARGS
		|| check_argv(argc, argv, "-help")
		|| check_argv_eq(argc, argv, "-h");
	if (show_help)
		exit_with_help();

	parse_command_line(argc, argv, model, train_path, result_path, seed_path);

	// load the training set named on the command line
	read_data(data, train_path);

	// the model takes its dimensions from the data set
	model->K = data->K;
	model->m = data->m;
	model->n = data->n;
	// data_file refers to the local buffer above; the string is not copied
	model->data_file = train_path;

	// set up model storage and the initial weights
	allocate_model(model);
	initialize_weights(data, model);

	// seed V: from a previously stored model when -m is present, else from NULL
	if (!check_argv_eq(argc, argv, "-m")) {
		seed_model_V(NULL, model);
	} else {
		seed = Malloc(struct Model, 1);
		read_model(seed, seed_path);
		seed_model_V(seed, model);
		free_model(seed);
	}

	// run the training
	main_loop(model, data);

	// store the trained model only when an output file was requested
	if (check_argv_eq(argc, argv, "-o")) {
		write_model(model, result_path);
		info("Output written to %s\n", result_path);
	}

	// clean up
	free_model(model);
	free_data(data);

	return 0;
}

/*
	Convert the value of a floating point option to a double.

	opt is the option letter (used only in the error message) and value is
	the argument that followed it. When value is not a complete number
	(nothing could be converted, or characters remain after the number),
	an error is written to stderr and exit_with_help() ends the program.
*/
static double parse_double_option(char opt, const char *value)
{
	char *end = NULL;
	double parsed = strtod(value, &end);

	if (end == value || *end != '\0') {
		fprintf(stderr, "Invalid value for option -%c: %s\n", opt, value);
		exit_with_help();
	}
	return parsed;
}

/*
	Convert the value of an integer option to an int.

	Same contract as parse_double_option(), with the extra requirement that
	the number fits in an int.
*/
static int parse_int_option(char opt, const char *value)
{
	char *end = NULL;
	long parsed = strtol(value, &end, 10);

	// the round trip through int rejects values that do not fit in an int
	if (end == value || *end != '\0' || (long)(int)parsed != parsed) {
		fprintf(stderr, "Invalid value for option -%c: %s\n", opt, value);
		exit_with_help();
	}
	return (int)parsed;
}

/*
	Parse the command line of trainMSVMMaj.

	First the defaults are stored in model (p = 1.0, lambda = 2^-8,
	epsilon = 1e-6, kappa = 0.0, weight_idx = 1). Then the leading arguments
	that start with '-' are read as options; only the character directly
	after the dash selects the option. Every option except -q consumes the
	next argument as its value. The first argument that does not start with
	'-' is taken as the training data file and copied to input_filename.

	Numeric values (-e, -k, -l, -p, -r) must be well-formed numbers; a
	malformed value, an unknown option, an option without a value, or a
	missing data file ends the program through exit_with_help().

	argc, argv      : the arguments as passed to main()
	model           : receives the defaults and the parsed parameters
	input_filename  : receives the training data file name
	output_filename : receives the value of -o; untouched without -o
	model_filename  : receives the value of -m; untouched without -m

	NOTE(review): the file names are copied with strcpy() and the capacity
	of the destination buffers is not known here. main() passes buffers of
	MAX_LINE_LENGTH bytes, so a longer argument overflows them. Bounding
	the copies needs the capacity to be available in this function.
*/
void parse_command_line(int argc, char **argv, struct Model *model, 
		char *input_filename, char *output_filename, char *model_filename)
{
	int i;
	void (*print_func)(const char*) = NULL;

	// default values
	model->p = 1.0;
	model->lambda = pow(2, -8.0);
	model->epsilon = 1e-6;
	model->kappa = 0.0;
	model->weight_idx = 1;

	// parse options
	for (i=1; i<argc; i++) {
		// the first argument without a dash ends the option list
		if (argv[i][0] != '-') break;
		// step to the value of the option; an option at the very end
		// has neither a value nor a data file after it
		if (++i>=argc) {
			exit_with_help();
		}
		switch (argv[i-1][1]) {
			case 'e':
				model->epsilon = parse_double_option('e', argv[i]);
				break;
			case 'k':
				model->kappa = parse_double_option('k', argv[i]);
				break;
			case 'l':
				model->lambda = parse_double_option('l', argv[i]);
				break;
			case 'm':
				strcpy(model_filename, argv[i]);
				break;
			case 'o':
				strcpy(output_filename, argv[i]);
				break;
			case 'p':
				model->p = parse_double_option('p', argv[i]);
				break;
			case 'r':
				model->weight_idx = parse_int_option('r', argv[i]);
				break;
			case 'q':
				print_func = &print_null;
				// -q has no value: step back so the argument after
				// it is examined as the next option
				i--;
				break;
			default:
				fprintf(stderr, "Unknown option: -%c\n", argv[i-1][1]);
				exit_with_help();
		}
	}

	// set print function
	set_print_string_function(print_func);

	// read input filename
	if (i >= argc)
		exit_with_help();

	strcpy(input_filename, argv[i]);
}