aboutsummaryrefslogtreecommitdiff
path: root/src/GenSVMpred.c
blob: 57680b1244ecc9ae170ac55170cc9888ebcf8472 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
/*
 * 20140317:
 * THIS PROGRAM IS DEPRECATED, SINCE IT DOES NOT WORK WITH KERNELS.
 *
 */

/**
 * @file GenSVMpred.c
 * @author Gertjan van den Burg
 * @date January, 2014
 * @brief Command line interface for predicting class labels
 *
 * @details
 * This is a command line program for predicting the class labels of a given
 * test dataset using a model read from a model file, or for determining the
 * predictive performance of that model on the test dataset. The predictive
 * performance is written to the screen when the test data contains class
 * labels, and the predicted class labels can be written to a specified output
 * file using gensvm_write_predictions().
 *
 * The specified model file must follow the specification given in
 * gensvm_write_model().
 *
 * For usage information, see the program help function.
 *
 */

#include "gensvm.h"
#include "gensvm_init.h"
#include "gensvm_io.h"
#include "gensvm_pred.h"
#include "gensvm_util.h"

/* Minimum value of argc: the program name plus the two required positional
 * arguments (test data file and model file). */
#define MINARGS 3

/* Global output stream, defined elsewhere in the library.
 * parse_command_line() sets it to stdout, or to NULL in quiet mode (-q). */
extern FILE *GENSVM_OUTPUT_FILE;

// function declarations (explicit "(void)" so the compiler checks calls)
void exit_with_help(void);
void parse_command_line(int argc, char **argv,
		char *input_filename, char *output_filename,
		char *model_filename);

/**
 * @brief Help function
 *
 * @details
 * Print the program version and usage information to stdout and terminate
 * the program. This is called when help is requested (-h / -help), when too
 * few arguments are given, and when the command line cannot be parsed.
 *
 * NOTE(review): the exit status is 0 in all of these cases, including usage
 * errors; callers that need to detect failure cannot distinguish them.
 */
void exit_with_help(void)
{
	printf("This is GenSVM, version %1.1f\n\n", VERSION);
	printf("Usage: predGenSVM [options] test_data_file model_file\n");
	printf("Options:\n");
	printf("-o output_file : write output to file\n");
	printf("-q : quiet mode (no output)\n");
	exit(0);
}

/**
 * @brief Main interface function for predGenSVM
 *
 * @details
 * Main interface for the command line program. A given model file is read and
 * a test dataset is initialized from the given data. If the test data has
 * class labels, the predictive performance (hitrate) of the model on the test
 * set is printed to the output stream (default = stdout). If an output file
 * is specified with the -o option, the predictions are written to that file.
 *
 * The program exits with status 1 if the number of attributes or the number
 * of classes in the test data does not match that of the model.
 *
 * @todo
 * Ensure that the program can handle test data without class labels. The
 * accuracy computation is already skipped when data->y is NULL, but the check
 * on the number of classes may still reject such data (TODO confirm what
 * gensvm_read_data() stores in data->K in that case).
 *
 * @param[in] 	argc 	number of command line arguments
 * @param[in] 	argv 	array of command line arguments
 * @returns 		0 on success
 *
 */
int main(int argc, char **argv)
{
	long *predy;

	char input_filename[MAX_LINE_LENGTH];
	char model_filename[MAX_LINE_LENGTH];
	char output_filename[MAX_LINE_LENGTH];

	// parse_command_line() only fills output_filename when it handles a
	// "-o" option, so an empty string means "no output file requested".
	output_filename[0] = '\0';

	if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help")
			|| gensvm_check_argv_eq(argc, argv, "-h") )
		exit_with_help();
	parse_command_line(argc, argv, input_filename, output_filename,
			model_filename);

	// read the data and model
	struct GenModel *model = gensvm_init_model();
	struct GenData *data = gensvm_init_data();
	gensvm_read_data(data, input_filename);
	gensvm_read_model(model, model_filename);

	// check if the number of attributes and classes in data equal those
	// in the model
	if (data->m != model->m) {
		fprintf(stderr, "Error: number of attributes in data (%li) "
				"does not equal the number of attributes in "
				"model (%li)\n", data->m, model->m);
		gensvm_free_model(model);
		gensvm_free_data(data);
		exit(1);
	} else if (data->K != model->K) {
		fprintf(stderr, "Error: number of classes in data (%li) "
				"does not equal the number of classes in "
				"model (%li)\n", data->K, model->K);
		gensvm_free_model(model);
		gensvm_free_data(data);
		exit(1);
	}

	// predict labels and performance if test data has labels
	predy = Calloc(long, data->n);
	gensvm_predict_labels(data, model, predy);
	if (data->y != NULL) {
		double performance = gensvm_prediction_perf(data, predy);
		note("Predictive performance: %3.2f%%\n", performance);
	}

	// Write predictions only if an output filename was actually parsed.
	// Scanning argv for "-o" is not enough: an "-o" placed after the
	// positional arguments is never parsed, and the filename buffer would
	// then be read uninitialized.
	if (output_filename[0] != '\0') {
		gensvm_write_predictions(data, predy, output_filename);
		note("Predictions written to: %s\n", output_filename);
	}

	// free the model, data, and predictions
	gensvm_free_model(model);
	gensvm_free_data(data);
	free(predy);

	return 0;
}

/**
 * @brief Parse command line arguments
 *
 * @details
 * Read the data filename and model filename from the command line arguments.
 * If specified, also read the output filename. If the quiet flag is given,
 * set the global output stream to NULL. On error, exit_with_help().
 *
 * @param[in] 	argc 			number of command line arguments
 * @param[in] 	argv 			array of command line arguments
 * @param[in] 	input_filename 		pre-allocated array for the input
 * 					filename
 * @param[in] 	output_filename 	pre-allocated array for the output
 * 					filename
 * @param[in] 	model_filename 		pre-allocated array for the model
 * 					filename
 *
 */
void parse_command_line(int argc, char **argv, char *input_filename,
		char *output_filename, char *model_filename)
{
	int i;

	GENSVM_OUTPUT_FILE = stdout;

	for (i=1; i<argc; i++) {
		if (argv[i][0] != '-') break;
		if (++i >= argc)
			exit_with_help();
		switch (argv[i-1][1]) {
			case 'o':
				strcpy(output_filename, argv[i]);
				break;
			case 'q':
				GENSVM_OUTPUT_FILE = NULL;
				i--;
				break;
			default:
				fprintf(stderr, "Unknown option: -%c\n",
						argv[i-1][1]);
				exit_with_help();
		}
	}

	if (i >= argc)
		exit_with_help();

	strcpy(input_filename, argv[i]);
	i++;
	strcpy(model_filename, argv[i]);
}