aboutsummaryrefslogtreecommitdiff
path: root/src/GenSVMgrid.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenSVMgrid.c')
-rw-r--r--src/GenSVMgrid.c108
1 files changed, 93 insertions, 15 deletions
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c
index 782200f..ab209cc 100644
--- a/src/GenSVMgrid.c
+++ b/src/GenSVMgrid.c
@@ -39,9 +39,10 @@
#include "gensvm_checks.h"
#include "gensvm_cmdarg.h"
+#include "gensvm_consistency.h"
#include "gensvm_io.h"
#include "gensvm_gridsearch.h"
-#include "gensvm_consistency.h"
+#include "gensvm_train.h"
/**
* Minimal number of command line arguments
@@ -53,7 +54,8 @@ extern FILE *GENSVM_ERROR_FILE;
// function declarations
void exit_with_help(char **argv);
-long parse_command_line(int argc, char **argv, char *input_filename);
+long parse_command_line(int argc, char **argv, char *input_filename,
+ char **prediction_outputfile);
void read_grid_from_file(char *input_filename, struct GenGrid *grid);
/**
@@ -75,6 +77,8 @@ void exit_with_help(char **argv)
printf("Usage: %s [options] grid_file\n", argv[0]);
printf("Options:\n");
printf("-h | -help : print this help.\n");
+ printf("-o prediction_output : write predictions of test data to "
+ "file (uses stdout if not provided)\n");
printf("-q : quiet mode (no output, not even errors!)\n");
printf("-x : data files are in LibSVM/SVMlight format\n");
printf("-z : seed for the random number generator\n");
@@ -102,9 +106,11 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ int i, best_ID = -1;
long seed;
bool libsvm_format = false;
char input_filename[GENSVM_MAX_LINE_LENGTH];
+ char *prediction_outputfile = NULL;
struct GenGrid *grid = gensvm_init_grid();
struct GenData *train_data = gensvm_init_data();
@@ -114,7 +120,8 @@ int main(int argc, char **argv)
if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help")
|| gensvm_check_argv_eq(argc, argv, "-h") )
exit_with_help(argv);
- seed = parse_command_line(argc, argv, input_filename);
+ seed = parse_command_line(argc, argv, input_filename,
+ &prediction_outputfile);
libsvm_format = gensvm_check_argv(argc, argv, "-x");
note("Reading grid file\n");
@@ -126,6 +133,18 @@ int main(int argc, char **argv)
else
gensvm_read_data(train_data, grid->train_data_file);
+ // Read the test data if present
+ if (grid->test_data_file != NULL) {
+ if (libsvm_format)
+ gensvm_read_data_libsvm(test_data,
+ grid->test_data_file);
+ else
+ gensvm_read_data(test_data, grid->test_data_file);
+ } else {
+ gensvm_free_data(test_data);
+ test_data = NULL;
+ }
+
// check labels of training data
gensvm_check_outcome_contiguous(train_data);
if (!gensvm_check_outcome_contiguous(train_data)) {
@@ -144,15 +163,6 @@ int main(int argc, char **argv)
gensvm_free_sparse(train_data->spZ);
}
- if (grid->traintype == TT) {
- err("[GenSVM Warning]: Using test datasets in a grid search "
- "is not yet supported in GenSVM.\n"
- " The test dataset will be "
- "ignored during training.\n");
- //note("Reading data from %s\n", grid->test_data_file);
- //gensvm_read_data(test_data, grid->test_data_file);
- }
-
note("Creating queue\n");
gensvm_fill_queue(grid, q, train_data, test_data);
@@ -163,7 +173,70 @@ int main(int argc, char **argv)
note("Training finished\n");
if (grid->repeats > 0) {
- gensvm_consistency_repeats(q, grid->repeats, grid->percentile);
+ best_ID = gensvm_consistency_repeats(q, grid->repeats,
+ grid->percentile);
+ } else {
+ double maxperf = -1;
+ for (i=0; i<q->N; i++) {
+ if (q->tasks[i]->performance > maxperf) {
+ maxperf = q->tasks[i]->performance;
+ best_ID = q->tasks[i]->ID;
+ }
+ }
+ }
+
+ // If we have test data, train best model on training data and predict
+ // test data
+ if (test_data) {
+ struct GenTask *best_task = NULL;
+ struct GenModel *best_model = NULL;
+ long *predy = NULL;
+ double performance = -1;
+
+ for (i=0; i<q->N; i++)
+ if (q->tasks[i]->ID == best_ID)
+ best_task = q->tasks[i];
+
+ best_model = gensvm_init_model();
+ gensvm_task_to_model(best_task, best_model);
+
+ gensvm_train(best_model, train_data, NULL);
+
+ // check if we are sparse and want nonlinearity
+ if (test_data->Z == NULL &&
+ best_model->kerneltype != K_LINEAR) {
+ err("[GenSVM Warning]: Sparse matrices with nonlinear "
+ "kernels are not yet supported. Dense "
+ "matrices will be used.\n");
+ test_data->Z = gensvm_sparse_to_dense(test_data->spZ);
+ gensvm_free_sparse(test_data->spZ);
+ }
+
+ gensvm_kernel_postprocess(best_model, train_data, test_data);
+
+ // predict labels
+ predy = Calloc(long, test_data->n);
+ gensvm_predict_labels(test_data, best_model, predy);
+
+ if (test_data->y != NULL) {
+ performance = gensvm_prediction_perf(test_data, predy);
+ note("Predictive performance: %3.2f%%\n", performance);
+ }
+
+ // if output file is specified, write predictions to it
+ if (gensvm_check_argv_eq(argc, argv, "-o")) {
+ gensvm_write_predictions(test_data, predy,
+ prediction_outputfile);
+ note("Prediction written to: %s\n",
+ prediction_outputfile);
+ } else {
+ for (i=0; i<test_data->n; i++)
+ printf("%li ", predy[i]);
+ printf("\n");
+ }
+
+ gensvm_free_model(best_model);
+ free(predy);
}
gensvm_free_queue(q);
@@ -191,7 +264,8 @@ int main(int argc, char **argv)
* @returns seed for the RNG
*
*/
-long parse_command_line(int argc, char **argv, char *input_filename)
+long parse_command_line(int argc, char **argv, char *input_filename,
+ char **prediction_outputfile)
{
long seed = time(NULL);
int i;
@@ -204,6 +278,11 @@ long parse_command_line(int argc, char **argv, char *input_filename)
if (++i>=argc)
exit_with_help(argv);
switch (argv[i-1][1]) {
+ case 'o':
+ (*prediction_outputfile) = Malloc(char,
+ strlen(argv[i]) + 1);
+ strcpy((*prediction_outputfile), argv[i]);
+ break;
case 'q':
GENSVM_OUTPUT_FILE = NULL;
GENSVM_ERROR_FILE = NULL;
@@ -304,7 +383,6 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid)
grid->test_data_file = Calloc(char,
GENSVM_MAX_LINE_LENGTH);
strcpy(grid->test_data_file, test_filename);
- grid->traintype = TT;
} else if (str_startswith(buffer, "p:")) {
nr = all_doubles_str(buffer, 2, params);
grid->ps = Calloc(double, nr);