aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-27 19:31:26 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-27 19:31:26 +0100
commit527e1e0162a287d11b92be48eb22c5f98d6e59b8 (patch)
tree8df244368fcbdf77e2a72860e879bda26c0f64d9
parentMajor bugfix for nonlinear GenSVM (diff)
downloadgensvm-527e1e0162a287d11b92be48eb22c5f98d6e59b8.tar.gz
gensvm-527e1e0162a287d11b92be48eb22c5f98d6e59b8.zip
Add support for predicting after grid search
With this commit the gensvm_grid executable can now compute predictions with the best model found during the grid search. The test dataset is supplied through the training file, and a command line flag is added to support saving the predictions in an output file.
-rw-r--r--include/gensvm_consistency.h2
-rw-r--r--src/GenSVMgrid.c108
-rw-r--r--src/gensvm_consistency.c15
3 files changed, 106 insertions, 19 deletions
diff --git a/include/gensvm_consistency.h b/include/gensvm_consistency.h
index 61e699d..d4572ee 100644
--- a/include/gensvm_consistency.h
+++ b/include/gensvm_consistency.h
@@ -37,7 +37,7 @@
// function declarations
struct GenQueue *gensvm_top_queue(struct GenQueue *q, double percentile);
int gensvm_dsort(const void *elem1, const void *elem2);
-void gensvm_consistency_repeats(struct GenQueue *q, long repeats,
+int gensvm_consistency_repeats(struct GenQueue *q, long repeats,
double percentile);
double gensvm_percentile(double *values, long N, double p);
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c
index 782200f..ab209cc 100644
--- a/src/GenSVMgrid.c
+++ b/src/GenSVMgrid.c
@@ -39,9 +39,10 @@
#include "gensvm_checks.h"
#include "gensvm_cmdarg.h"
+#include "gensvm_consistency.h"
#include "gensvm_io.h"
#include "gensvm_gridsearch.h"
-#include "gensvm_consistency.h"
+#include "gensvm_train.h"
/**
* Minimal number of command line arguments
@@ -53,7 +54,8 @@ extern FILE *GENSVM_ERROR_FILE;
// function declarations
void exit_with_help(char **argv);
-long parse_command_line(int argc, char **argv, char *input_filename);
+long parse_command_line(int argc, char **argv, char *input_filename,
+ char **prediction_outputfile);
void read_grid_from_file(char *input_filename, struct GenGrid *grid);
/**
@@ -75,6 +77,8 @@ void exit_with_help(char **argv)
printf("Usage: %s [options] grid_file\n", argv[0]);
printf("Options:\n");
printf("-h | -help : print this help.\n");
+ printf("-o prediction_output : write predictions of test data to "
+ "file (uses stdout if not provided)\n");
printf("-q : quiet mode (no output, not even errors!)\n");
printf("-x : data files are in LibSVM/SVMlight format\n");
printf("-z : seed for the random number generator\n");
@@ -102,9 +106,11 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ int i, best_ID = -1;
long seed;
bool libsvm_format = false;
char input_filename[GENSVM_MAX_LINE_LENGTH];
+ char *prediction_outputfile = NULL;
struct GenGrid *grid = gensvm_init_grid();
struct GenData *train_data = gensvm_init_data();
@@ -114,7 +120,8 @@ int main(int argc, char **argv)
if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help")
|| gensvm_check_argv_eq(argc, argv, "-h") )
exit_with_help(argv);
- seed = parse_command_line(argc, argv, input_filename);
+ seed = parse_command_line(argc, argv, input_filename,
+ &prediction_outputfile);
libsvm_format = gensvm_check_argv(argc, argv, "-x");
note("Reading grid file\n");
@@ -126,6 +133,18 @@ int main(int argc, char **argv)
else
gensvm_read_data(train_data, grid->train_data_file);
+ // Read the test data if present
+ if (grid->test_data_file != NULL) {
+ if (libsvm_format)
+ gensvm_read_data_libsvm(test_data,
+ grid->test_data_file);
+ else
+ gensvm_read_data(test_data, grid->test_data_file);
+ } else {
+ gensvm_free_data(test_data);
+ test_data = NULL;
+ }
+
// check labels of training data
gensvm_check_outcome_contiguous(train_data);
if (!gensvm_check_outcome_contiguous(train_data)) {
@@ -144,15 +163,6 @@ int main(int argc, char **argv)
gensvm_free_sparse(train_data->spZ);
}
- if (grid->traintype == TT) {
- err("[GenSVM Warning]: Using test datasets in a grid search "
- "is not yet supported in GenSVM.\n"
- " The test dataset will be "
- "ignored during training.\n");
- //note("Reading data from %s\n", grid->test_data_file);
- //gensvm_read_data(test_data, grid->test_data_file);
- }
-
note("Creating queue\n");
gensvm_fill_queue(grid, q, train_data, test_data);
@@ -163,7 +173,70 @@ int main(int argc, char **argv)
note("Training finished\n");
if (grid->repeats > 0) {
- gensvm_consistency_repeats(q, grid->repeats, grid->percentile);
+ best_ID = gensvm_consistency_repeats(q, grid->repeats,
+ grid->percentile);
+ } else {
+ double maxperf = -1;
+ for (i=0; i<q->N; i++) {
+ if (q->tasks[i]->performance > maxperf) {
+ maxperf = q->tasks[i]->performance;
+ best_ID = q->tasks[i]->ID;
+ }
+ }
+ }
+
+ // If we have test data, train best model on training data and predict
+ // test data
+ if (test_data) {
+ struct GenTask *best_task = NULL;
+ struct GenModel *best_model = NULL;
+ long *predy = NULL;
+ double performance = -1;
+
+ for (i=0; i<q->N; i++)
+ if (q->tasks[i]->ID == best_ID)
+ best_task = q->tasks[i];
+
+ best_model = gensvm_init_model();
+ gensvm_task_to_model(best_task, best_model);
+
+ gensvm_train(best_model, train_data, NULL);
+
+ // check if we are sparse and want nonlinearity
+ if (test_data->Z == NULL &&
+ best_model->kerneltype != K_LINEAR) {
+ err("[GenSVM Warning]: Sparse matrices with nonlinear "
+ "kernels are not yet supported. Dense "
+ "matrices will be used.\n");
+ test_data->Z = gensvm_sparse_to_dense(test_data->spZ);
+ gensvm_free_sparse(test_data->spZ);
+ }
+
+ gensvm_kernel_postprocess(best_model, train_data, test_data);
+
+ // predict labels
+ predy = Calloc(long, test_data->n);
+ gensvm_predict_labels(test_data, best_model, predy);
+
+ if (test_data->y != NULL) {
+ performance = gensvm_prediction_perf(test_data, predy);
+ note("Predictive performance: %3.2f%%\n", performance);
+ }
+
+ // if output file is specified, write predictions to it
+ if (gensvm_check_argv_eq(argc, argv, "-o")) {
+ gensvm_write_predictions(test_data, predy,
+ prediction_outputfile);
+ note("Prediction written to: %s\n",
+ prediction_outputfile);
+ } else {
+ for (i=0; i<test_data->n; i++)
+ printf("%li ", predy[i]);
+ printf("\n");
+ }
+
+ gensvm_free_model(best_model);
+ free(predy);
}
gensvm_free_queue(q);
@@ -191,7 +264,8 @@ int main(int argc, char **argv)
* @returns seed for the RNG
*
*/
-long parse_command_line(int argc, char **argv, char *input_filename)
+long parse_command_line(int argc, char **argv, char *input_filename,
+ char **prediction_outputfile)
{
long seed = time(NULL);
int i;
@@ -204,6 +278,11 @@ long parse_command_line(int argc, char **argv, char *input_filename)
if (++i>=argc)
exit_with_help(argv);
switch (argv[i-1][1]) {
+ case 'o':
+ (*prediction_outputfile) = Malloc(char,
+ strlen(argv[i]) + 1);
+ strcpy((*prediction_outputfile), argv[i]);
+ break;
case 'q':
GENSVM_OUTPUT_FILE = NULL;
GENSVM_ERROR_FILE = NULL;
@@ -304,7 +383,6 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid)
grid->test_data_file = Calloc(char,
GENSVM_MAX_LINE_LENGTH);
strcpy(grid->test_data_file, test_filename);
- grid->traintype = TT;
} else if (str_startswith(buffer, "p:")) {
nr = all_doubles_str(buffer, 2, params);
grid->ps = Calloc(double, nr);
diff --git a/src/gensvm_consistency.c b/src/gensvm_consistency.c
index 07ff7a6..0aa33c0 100644
--- a/src/gensvm_consistency.c
+++ b/src/gensvm_consistency.c
@@ -124,8 +124,11 @@ struct GenQueue *gensvm_top_queue(struct GenQueue *q, double percentile)
* configurations for consistency
* @param[in] percentile percentile of performance to determine which
* tasks to repeat
+ *
+ * @return ID of the best task
+ *
*/
-void gensvm_consistency_repeats(struct GenQueue *q, long repeats,
+int gensvm_consistency_repeats(struct GenQueue *q, long repeats,
double percentile)
{
bool breakout;
@@ -229,11 +232,12 @@ void gensvm_consistency_repeats(struct GenQueue *q, long repeats,
"mean_perf\tstd_perf\ttime_perf\n");
p = 0.0;
breakout = false;
- while (breakout == false) {
+ int best_id = -1;
+ while (!breakout) {
pi = gensvm_percentile(mean, N, (100.0-p));
pr = gensvm_percentile(std, N, p);
pt = gensvm_percentile(time, N, p);
- for (i=0; i<N; i++)
+ for (i=0; i<N; i++) {
if ((pi - mean[i] < 0.0001) &&
(std[i] - pr < 0.0001) &&
(time[i] - pt < 0.0001)) {
@@ -251,7 +255,10 @@ void gensvm_consistency_repeats(struct GenQueue *q, long repeats,
std[i],
time[i]);
breakout = true;
+ if (best_id == -1)
+ best_id = nq->tasks[i]->ID;
}
+ }
p += 1.0;
}
@@ -263,6 +270,8 @@ void gensvm_consistency_repeats(struct GenQueue *q, long repeats,
free(std);
free(mean);
free(time);
+
+ return best_id;
}
/**