diff options
Diffstat (limited to 'src/GenSVMgrid.c')
| -rw-r--r-- | src/GenSVMgrid.c | 108 |
1 files changed, 93 insertions, 15 deletions
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 782200f..ab209cc 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -39,9 +39,10 @@ #include "gensvm_checks.h" #include "gensvm_cmdarg.h" +#include "gensvm_consistency.h" #include "gensvm_io.h" #include "gensvm_gridsearch.h" -#include "gensvm_consistency.h" +#include "gensvm_train.h" /** * Minimal number of command line arguments @@ -53,7 +54,8 @@ extern FILE *GENSVM_ERROR_FILE; // function declarations void exit_with_help(char **argv); -long parse_command_line(int argc, char **argv, char *input_filename); +long parse_command_line(int argc, char **argv, char *input_filename, + char **prediction_outputfile); void read_grid_from_file(char *input_filename, struct GenGrid *grid); /** @@ -75,6 +77,8 @@ void exit_with_help(char **argv) printf("Usage: %s [options] grid_file\n", argv[0]); printf("Options:\n"); printf("-h | -help : print this help.\n"); + printf("-o prediction_output : write predictions of test data to " + "file (uses stdout if not provided)\n"); printf("-q : quiet mode (no output, not even errors!)\n"); printf("-x : data files are in LibSVM/SVMlight format\n"); printf("-z : seed for the random number generator\n"); @@ -102,9 +106,11 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + int i, best_ID = -1; long seed; bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; + char *prediction_outputfile = NULL; struct GenGrid *grid = gensvm_init_grid(); struct GenData *train_data = gensvm_init_data(); @@ -114,7 +120,8 @@ int main(int argc, char **argv) if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help") || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); - seed = parse_command_line(argc, argv, input_filename); + seed = parse_command_line(argc, argv, input_filename, + &prediction_outputfile); libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); @@ -126,6 +133,18 @@ int main(int argc, char **argv) else gensvm_read_data(train_data, grid->train_data_file); + // Read the test data if present + if (grid->test_data_file != NULL) { + if (libsvm_format) + gensvm_read_data_libsvm(test_data, + grid->test_data_file); + else + gensvm_read_data(test_data, grid->test_data_file); + } else { + gensvm_free_data(test_data); + test_data = NULL; + } + // check labels of training data gensvm_check_outcome_contiguous(train_data); if (!gensvm_check_outcome_contiguous(train_data)) { @@ -144,15 +163,6 @@ int main(int argc, char **argv) gensvm_free_sparse(train_data->spZ); } - if (grid->traintype == TT) { - err("[GenSVM Warning]: Using test datasets in a grid search " - "is not yet supported in GenSVM.\n" - " The test dataset will be " - "ignored during training.\n"); - //note("Reading data from %s\n", grid->test_data_file); - //gensvm_read_data(test_data, grid->test_data_file); - } - note("Creating queue\n"); gensvm_fill_queue(grid, q, train_data, test_data); @@ -163,7 +173,70 @@ int main(int argc, char **argv) note("Training finished\n"); if (grid->repeats > 0) { - gensvm_consistency_repeats(q, grid->repeats, grid->percentile); + best_ID = gensvm_consistency_repeats(q, grid->repeats, + grid->percentile); + } else { + double maxperf = -1; + for (i=0; i<q->N; i++) { + if (q->tasks[i]->performance > maxperf) { + maxperf = q->tasks[i]->performance; + best_ID = q->tasks[i]->ID; + } + } + } + + // If we have test data, train best model on training data and predict + // test data + if (test_data) { + struct GenTask *best_task = NULL; + struct GenModel *best_model = NULL; + long *predy = NULL; + double performance = -1; + + for (i=0; i<q->N; i++) + if (q->tasks[i]->ID == best_ID) + best_task = q->tasks[i]; + + best_model = gensvm_init_model(); + gensvm_task_to_model(best_task, best_model); + + gensvm_train(best_model, train_data, NULL); + + // check if we are sparse and want nonlinearity + if (test_data->Z == NULL && + best_model->kerneltype != K_LINEAR) { + err("[GenSVM Warning]: Sparse matrices with nonlinear " + "kernels are not yet supported. Dense " + "matrices will be used.\n"); + test_data->Z = gensvm_sparse_to_dense(test_data->spZ); + gensvm_free_sparse(test_data->spZ); + } + + gensvm_kernel_postprocess(best_model, train_data, test_data); + + // predict labels + predy = Calloc(long, test_data->n); + gensvm_predict_labels(test_data, best_model, predy); + + if (test_data->y != NULL) { + performance = gensvm_prediction_perf(test_data, predy); + note("Predictive performance: %3.2f%%\n", performance); + } + + // if output file is specified, write predictions to it + if (gensvm_check_argv_eq(argc, argv, "-o")) { + gensvm_write_predictions(test_data, predy, + prediction_outputfile); + note("Prediction written to: %s\n", + prediction_outputfile); + } else { + for (i=0; i<test_data->n; i++) + printf("%li ", predy[i]); + printf("\n"); + } + + gensvm_free_model(best_model); + free(predy); } gensvm_free_queue(q); @@ -191,7 +264,8 @@ int main(int argc, char **argv) * @returns seed for the RNG * */ -long parse_command_line(int argc, char **argv, char *input_filename) +long parse_command_line(int argc, char **argv, char *input_filename, + char **prediction_outputfile) { long seed = time(NULL); int i; @@ -204,6 +278,11 @@ long parse_command_line(int argc, char **argv, char *input_filename) if (++i>=argc) exit_with_help(argv); switch (argv[i-1][1]) { + case 'o': + (*prediction_outputfile) = Malloc(char, + strlen(argv[i]) + 1); + strcpy((*prediction_outputfile), argv[i]); + break; case 'q': GENSVM_OUTPUT_FILE = NULL; GENSVM_ERROR_FILE = NULL; @@ -304,7 +383,6 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid) grid->test_data_file = Calloc(char, GENSVM_MAX_LINE_LENGTH); strcpy(grid->test_data_file, test_filename); - grid->traintype = TT; } else if (str_startswith(buffer, "p:")) { nr = all_doubles_str(buffer, 2, params); grid->ps = Calloc(double, nr); |
