diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 19:31:26 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 19:31:26 +0100 |
| commit | 527e1e0162a287d11b92be48eb22c5f98d6e59b8 (patch) | |
| tree | 8df244368fcbdf77e2a72860e879bda26c0f64d9 | |
| parent | Major bugfix for nonlinear GenSVM (diff) | |
| download | gensvm-527e1e0162a287d11b92be48eb22c5f98d6e59b8.tar.gz gensvm-527e1e0162a287d11b92be48eb22c5f98d6e59b8.zip | |
Add support for predicting after grid search
With this commit the gensvm_grid executable can now compute
predictions with the best model found during the grid search.
The test dataset is supplied through the training file, and
a command line flag is added to support saving the predictions
in an output file.
| -rw-r--r-- | include/gensvm_consistency.h | 2 | ||||
| -rw-r--r-- | src/GenSVMgrid.c | 108 | ||||
| -rw-r--r-- | src/gensvm_consistency.c | 15 |
3 files changed, 106 insertions, 19 deletions
diff --git a/include/gensvm_consistency.h b/include/gensvm_consistency.h index 61e699d..d4572ee 100644 --- a/include/gensvm_consistency.h +++ b/include/gensvm_consistency.h @@ -37,7 +37,7 @@ // function declarations struct GenQueue *gensvm_top_queue(struct GenQueue *q, double percentile); int gensvm_dsort(const void *elem1, const void *elem2); -void gensvm_consistency_repeats(struct GenQueue *q, long repeats, +int gensvm_consistency_repeats(struct GenQueue *q, long repeats, double percentile); double gensvm_percentile(double *values, long N, double p); diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 782200f..ab209cc 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -39,9 +39,10 @@ #include "gensvm_checks.h" #include "gensvm_cmdarg.h" +#include "gensvm_consistency.h" #include "gensvm_io.h" #include "gensvm_gridsearch.h" -#include "gensvm_consistency.h" +#include "gensvm_train.h" /** * Minimal number of command line arguments @@ -53,7 +54,8 @@ extern FILE *GENSVM_ERROR_FILE; // function declarations void exit_with_help(char **argv); -long parse_command_line(int argc, char **argv, char *input_filename); +long parse_command_line(int argc, char **argv, char *input_filename, + char **prediction_outputfile); void read_grid_from_file(char *input_filename, struct GenGrid *grid); /** @@ -75,6 +77,8 @@ void exit_with_help(char **argv) printf("Usage: %s [options] grid_file\n", argv[0]); printf("Options:\n"); printf("-h | -help : print this help.\n"); + printf("-o prediction_output : write predictions of test data to " + "file (uses stdout if not provided)\n"); printf("-q : quiet mode (no output, not even errors!)\n"); printf("-x : data files are in LibSVM/SVMlight format\n"); printf("-z : seed for the random number generator\n"); @@ -102,9 +106,11 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + int i, best_ID = -1; long seed; bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; + char *prediction_outputfile = NULL; struct GenGrid *grid = gensvm_init_grid(); struct GenData *train_data = gensvm_init_data(); @@ -114,7 +120,8 @@ int main(int argc, char **argv) if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help") || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); - seed = parse_command_line(argc, argv, input_filename); + seed = parse_command_line(argc, argv, input_filename, + &prediction_outputfile); libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); @@ -126,6 +133,18 @@ int main(int argc, char **argv) else gensvm_read_data(train_data, grid->train_data_file); + // Read the test data if present + if (grid->test_data_file != NULL) { + if (libsvm_format) + gensvm_read_data_libsvm(test_data, + grid->test_data_file); + else + gensvm_read_data(test_data, grid->test_data_file); + } else { + gensvm_free_data(test_data); + test_data = NULL; + } + // check labels of training data gensvm_check_outcome_contiguous(train_data); if (!gensvm_check_outcome_contiguous(train_data)) { @@ -144,15 +163,6 @@ int main(int argc, char **argv) gensvm_free_sparse(train_data->spZ); } - if (grid->traintype == TT) { - err("[GenSVM Warning]: Using test datasets in a grid search " - "is not yet supported in GenSVM.\n" - " The test dataset will be " - "ignored during training.\n"); - //note("Reading data from %s\n", grid->test_data_file); - //gensvm_read_data(test_data, grid->test_data_file); - } - note("Creating queue\n"); gensvm_fill_queue(grid, q, train_data, test_data); @@ -163,7 +173,70 @@ int main(int argc, char **argv) note("Training finished\n"); if (grid->repeats > 0) { - gensvm_consistency_repeats(q, grid->repeats, grid->percentile); + best_ID = gensvm_consistency_repeats(q, grid->repeats, + grid->percentile); + } else { + double maxperf = -1; + for (i=0; i<q->N; i++) { + if (q->tasks[i]->performance > maxperf) { + maxperf = q->tasks[i]->performance; + best_ID = q->tasks[i]->ID; + } + } + } + + // If we have test data, train best model on training data and predict + // test data + if (test_data) { + struct GenTask *best_task = NULL; + struct GenModel *best_model = NULL; + long *predy = NULL; + double performance = -1; + + for (i=0; i<q->N; i++) + if (q->tasks[i]->ID == best_ID) + best_task = q->tasks[i]; + + best_model = gensvm_init_model(); + gensvm_task_to_model(best_task, best_model); + + gensvm_train(best_model, train_data, NULL); + + // check if we are sparse and want nonlinearity + if (test_data->Z == NULL && + best_model->kerneltype != K_LINEAR) { + err("[GenSVM Warning]: Sparse matrices with nonlinear " + "kernels are not yet supported. Dense " + "matrices will be used.\n"); + test_data->Z = gensvm_sparse_to_dense(test_data->spZ); + gensvm_free_sparse(test_data->spZ); + } + + gensvm_kernel_postprocess(best_model, train_data, test_data); + + // predict labels + predy = Calloc(long, test_data->n); + gensvm_predict_labels(test_data, best_model, predy); + + if (test_data->y != NULL) { + performance = gensvm_prediction_perf(test_data, predy); + note("Predictive performance: %3.2f%%\n", performance); + } + + // if output file is specified, write predictions to it + if (gensvm_check_argv_eq(argc, argv, "-o")) { + gensvm_write_predictions(test_data, predy, + prediction_outputfile); + note("Prediction written to: %s\n", + prediction_outputfile); + } else { + for (i=0; i<test_data->n; i++) + printf("%li ", predy[i]); + printf("\n"); + } + + gensvm_free_model(best_model); + free(predy); } gensvm_free_queue(q); @@ -191,7 +264,8 @@ int main(int argc, char **argv) * @returns seed for the RNG * */ -long parse_command_line(int argc, char **argv, char *input_filename) +long parse_command_line(int argc, char **argv, char *input_filename, + char **prediction_outputfile) { long seed = time(NULL); int i; @@ -204,6 +278,11 @@ long parse_command_line(int argc, char **argv, char *input_filename) if (++i>=argc) exit_with_help(argv); switch (argv[i-1][1]) { + case 'o': + (*prediction_outputfile) = Malloc(char, + strlen(argv[i]) + 1); + strcpy((*prediction_outputfile), argv[i]); + break; case 'q': GENSVM_OUTPUT_FILE = NULL; GENSVM_ERROR_FILE = NULL; @@ -304,7 +383,6 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid) grid->test_data_file = Calloc(char, GENSVM_MAX_LINE_LENGTH); strcpy(grid->test_data_file, test_filename); - grid->traintype = TT; } else if (str_startswith(buffer, "p:")) { nr = all_doubles_str(buffer, 2, params); grid->ps = Calloc(double, nr); diff --git a/src/gensvm_consistency.c b/src/gensvm_consistency.c index 07ff7a6..0aa33c0 100644 --- a/src/gensvm_consistency.c +++ b/src/gensvm_consistency.c @@ -124,8 +124,11 @@ struct GenQueue *gensvm_top_queue(struct GenQueue *q, double percentile) * configurations for consistency * @param[in] percentile percentile of performance to determine which * tasks to repeat + * + * @return ID of the best task + * */ -void gensvm_consistency_repeats(struct GenQueue *q, long repeats, +int gensvm_consistency_repeats(struct GenQueue *q, long repeats, double percentile) { bool breakout; @@ -229,11 +232,12 @@ void gensvm_consistency_repeats(struct GenQueue *q, long repeats, "mean_perf\tstd_perf\ttime_perf\n"); p = 0.0; breakout = false; - while (breakout == false) { + int best_id = -1; + while (!breakout) { pi = gensvm_percentile(mean, N, (100.0-p)); pr = gensvm_percentile(std, N, p); pt = gensvm_percentile(time, N, p); - for (i=0; i<N; i++) + for (i=0; i<N; i++) { if ((pi - mean[i] < 0.0001) && (std[i] - pr < 0.0001) && (time[i] - pt < 0.0001)) { @@ -251,7 +255,10 @@ void gensvm_consistency_repeats(struct GenQueue *q, long repeats, std[i], time[i]); breakout = true; + if (best_id == -1) + best_id = nq->tasks[i]->ID; } + } p += 1.0; } @@ -263,6 +270,8 @@ void gensvm_consistency_repeats(struct GenQueue *q, long repeats, free(std); free(mean); free(time); + + return best_id; } /** |
