/**
* @file GenSVMtraintest.c
* @author G.J.J. van den Burg
* @date 2015-02-01
* @brief Command line interface for training and testing with a GenSVM model
*
* @details
* This is a command line program for training and testing on a single model
* with specified model parameters.
*
* @copyright
Copyright 2016, G.J.J. van den Burg.
This file is part of GenSVM.
GenSVM is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
GenSVM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
*/
#include "gensvm_checks.h"
#include "gensvm_cmdarg.h"
#include "gensvm_io.h"
#include "gensvm_train.h"
#include "gensvm_predict.h"
/**
* Minimal number of command line arguments
*
* main() prints the help text and exits when argc is smaller than this
* value, i.e. when no argument follows the program name.
*/
#define MINARGS 2
// Output and error streams, defined outside this file. In this file they are
// set by parse_command_line(): stdout/stderr by default, and both NULL when
// the -q flag is given.
extern FILE *GENSVM_OUTPUT_FILE;
extern FILE *GENSVM_ERROR_FILE;
// function declarations
// Prints the usage text and terminates the program (does not return).
void exit_with_help(char **argv);
// Parses the option flags into `model` and copies the file name arguments
// into newly allocated strings stored through the char** output parameters.
void parse_command_line(int argc, char **argv, struct GenModel *model,
char **model_inputfile, char **training_inputfile,
char **testing_inputfile, char **model_outputfile,
char **prediction_outputfile);
/**
 * @brief Help function
 *
 * @details
 * Print the version banner, the usage line and the list of supported options
 * to stdout, then terminate the program with EXIT_FAILURE. This function does
 * not return. VERSION_STRING is not defined in this file; it is presumably
 * supplied by the build configuration.
 *
 * @param[in] 	argv 	command line arguments, argv[0] is used as the
 * 			program name in the usage line
 *
 */
void exit_with_help(char **argv)
{
	printf("This is GenSVM, version %s.\n", VERSION_STRING);
	printf("Copyright (C) 2016, G.J.J. van den Burg.\n");
	printf("This program is free software, see the LICENSE file "
			"for details.\n\n");
	printf("Usage: %s [options] training_data [test_data]\n\n", argv[0]);
	printf("Options:\n");
	printf("--------\n");
	printf("-c coef : coefficient for the polynomial and "
			"sigmoid kernel\n");
	printf("-d degree : degree for the polynomial kernel\n");
	printf("-e epsilon : set the value of the stopping "
			"criterion (epsilon > 0)\n");
	printf("-g gamma : parameter for the rbf, polynomial or "
			"sigmoid kernel\n");
	printf("-h | -help : print this help.\n");
	printf("-i max_iter : maximum number of iterations to do.\n");
	printf("-k kappa : set the value of kappa used in the "
			"Huber hinge (kappa > -1.0)\n");
	printf("-l lambda : set the value of lambda "
			"(lambda > 0)\n");
	printf("-m model_output_file : write model output to file "
			"(not saved if no file provided)\n");
	printf("-o prediction_output : write predictions of test data to "
			"file (uses stdout if not provided)\n");
	printf("-p p-value : set the value of p in the lp norm "
			"(1.0 <= p <= 2.0)\n");
	printf("-q : quiet mode (no output, not even "
			"errors!)\n");
	// fixed spelling: "weigth" -> "weight"
	printf("-r rho : choose the weight specification "
			"(1 = unit, 2 = group)\n");
	printf("-s seed_model_file : use previous model as seed for V\n");
	printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, "
			"3=SIGMOID)\n");
	printf("-x : data files are in LibSVM/SVMlight "
			"format\n");
	printf("-z seed : seed for the random number generator\n");
	printf("\n");

	exit(EXIT_FAILURE);
}
/**
* @brief Main interface function for GenSVMtraintest
*
* @details
* Main interface for the GenSVMtraintest commandline program.
*
* @param[in] argc number of command line arguments
* @param[in] argv array of command line arguments
*
* @return exit status
*/
int main(int argc, char **argv)
{
bool libsvm_format = false;
long i, *predy = NULL;
double performance;
char *training_inputfile = NULL,
*testing_inputfile = NULL,
*model_inputfile = NULL,
*model_outputfile = NULL,
*prediction_outputfile = NULL;
struct GenModel *model = gensvm_init_model();
struct GenModel *seed_model = NULL;
struct GenData *traindata = gensvm_init_data();
struct GenData *testdata = gensvm_init_data();
if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help")
|| gensvm_check_argv_eq(argc, argv, "-h"))
exit_with_help(argv);
// parse command line arguments
parse_command_line(argc, argv, model, &model_inputfile,
&training_inputfile, &testing_inputfile,
&model_outputfile, &prediction_outputfile);
libsvm_format = gensvm_check_argv(argc, argv, "-x");
// read data from file
if (libsvm_format)
gensvm_read_data_libsvm(traindata, training_inputfile);
else
gensvm_read_data(traindata, training_inputfile);
// check labels for consistency
if (!gensvm_check_outcome_contiguous(traindata)) {
err("[GenSVM Error]: Class labels should start from 1 and "
"have no gaps. Please reformat your data.\n");
exit(EXIT_FAILURE);
}
// save data filename to model
model->data_file = Calloc(char, GENSVM_MAX_LINE_LENGTH);
strcpy(model->data_file, training_inputfile);
// check if we are sparse and want nonlinearity
if (traindata->Z == NULL && model->kerneltype != K_LINEAR) {
err("[GenSVM Warning]: Sparse matrices with nonlinear kernels "
"are not yet supported. Dense matrices will "
"be used.\n");
traindata->RAW = gensvm_sparse_to_dense(traindata->spZ);
traindata->Z = traindata->RAW;
gensvm_free_sparse(traindata->spZ);
}
// load a seed model from file if it is specified
if (gensvm_check_argv_eq(argc, argv, "-s")) {
seed_model = gensvm_init_model();
gensvm_read_model(seed_model, model_inputfile);
}
// train the GenSVM model
gensvm_train(model, traindata, seed_model);
// if we also have a test set, predict labels and write to predictions
// to an output file if specified
if (testing_inputfile != NULL) {
// read the test data
if (libsvm_format)
gensvm_read_data_libsvm(testdata, testing_inputfile);
else
gensvm_read_data(testdata, testing_inputfile);
// check if we are sparse and want nonlinearity
if (testdata->Z == NULL && model->kerneltype != K_LINEAR) {
err("[GenSVM Warning]: Sparse matrices with nonlinear "
"kernels are not yet supported. Dense "
"matrices will be used.\n");
testdata->Z = gensvm_sparse_to_dense(testdata->spZ);
gensvm_free_sparse(testdata->spZ);
}
gensvm_kernel_postprocess(model, traindata, testdata);
// predict labels
predy = Calloc(long, testdata->n);
gensvm_predict_labels(testdata, model, predy);
if (testdata->y != NULL) {
performance = gensvm_prediction_perf(testdata, predy);
note("Predictive performance: %3.2f%%\n", performance);
}
// if output file is specified, write predictions to it
if (gensvm_check_argv_eq(argc, argv, "-o")) {
gensvm_write_predictions(testdata, predy,
prediction_outputfile);
note("Prediction written to: %s\n",
prediction_outputfile);
} else {
for (i=0; in; i++)
printf("%li ", predy[i]);
printf("\n");
}
}
// write model to output file if necessary
if (gensvm_check_argv_eq(argc, argv, "-m")) {
gensvm_write_model(model, model_outputfile);
note("Model written to: %s\n", model_outputfile);
}
// free everything
gensvm_free_model(model);
gensvm_free_model(seed_model);
gensvm_free_data(traindata);
gensvm_free_data(testdata);
free(training_inputfile);
free(testing_inputfile);
free(model_inputfile);
free(model_outputfile);
free(prediction_outputfile);
free(predy);
return 0;
}
/**
 * @brief Report an invalid parameter value and terminate via the help text
 *
 * @details
 * Writes a one-line message naming the offending parameter to stderr and
 * then hands control to exit_with_help(), which prints the usage text and
 * ends the program.
 *
 * @param[in] 	label 	name of the parameter that has an invalid value
 * @param[in] 	argv 	command line arguments, forwarded to exit_with_help()
 */
void exit_invalid_param(const char *label, char **argv)
{
	// written directly to stderr, independent of GENSVM_ERROR_FILE
	fprintf(stderr,
			"Invalid parameter value for %s.\n\n",
			label);

	exit_with_help(argv);
}
/**
 * @brief Parse the command line arguments
 *
 * @details
 * For a full overview of the command line arguments and their meaning see
 * exit_with_help(). This function furthermore sets the default output streams
 * to stdout/stderr (or to NULL for both when -q is supplied).
 *
 * Options are read from the start of the argument list until the first
 * argument that does not begin with '-'. That argument is taken as the
 * training data file name; if exactly one more argument follows, it is taken
 * as the test data file name. All file names are copied into newly allocated
 * strings which the caller is responsible for freeing.
 *
 * The program is terminated through exit_with_help() when an option is
 * unknown, when the argument list ends directly after an option, when no
 * training file is given, or (via exit_invalid_param()) when epsilon, kappa,
 * lambda or p is outside its allowed range.
 *
 * @param[in] 	argc 			number of command line arguments
 * @param[in] 	argv 			array of command line arguments
 * @param[in] 	model 			initialized GenModel struct
 * @param[out] 	model_inputfile 	filename for the seed model
 * @param[out] 	training_inputfile 	filename for the training data
 * @param[out] 	testing_inputfile 	filename for the test data
 * @param[out] 	model_outputfile 	filename for the output model
 * @param[out] 	prediction_outputfile 	filename for the predictions
 *
 */
void parse_command_line(int argc, char **argv, struct GenModel *model,
		char **model_inputfile, char **training_inputfile,
		char **testing_inputfile, char **model_outputfile,
		char **prediction_outputfile)
{
	int i;

	GENSVM_OUTPUT_FILE = stdout;
	GENSVM_ERROR_FILE = stderr;

	// parse options
	// note: inside the switch argv[i-1] is the flag and argv[i] is its
	// value; flags that don't have an argument should decrement i
	for (i=1; i<argc; i++) {
		// first non-option argument ends the option list
		if (argv[i][0] != '-')
			break;
		// every option must be followed by at least one argument
		if (++i >= argc) {
			exit_with_help(argv);
		}
		switch (argv[i-1][1]) {
			case 'c':
				model->coef = atof(argv[i]);
				break;
			case 'd':
				model->degree = atof(argv[i]);
				break;
			case 'e':
				model->epsilon = atof(argv[i]);
				if (model->epsilon <= 0)
					exit_invalid_param("epsilon", argv);
				break;
			case 'g':
				model->gamma = atof(argv[i]);
				break;
			case 'i':
				model->max_iter = atoi(argv[i]);
				break;
			case 'k':
				model->kappa = atof(argv[i]);
				if (model->kappa <= -1.0)
					exit_invalid_param("kappa", argv);
				break;
			case 'l':
				model->lambda = atof(argv[i]);
				if (model->lambda <= 0)
					exit_invalid_param("lambda", argv);
				break;
			case 's':
				(*model_inputfile) = Malloc(char,
						strlen(argv[i])+1);
				strcpy((*model_inputfile), argv[i]);
				break;
			case 'm':
				(*model_outputfile) = Malloc(char,
						strlen(argv[i])+1);
				strcpy((*model_outputfile), argv[i]);
				break;
			case 'o':
				(*prediction_outputfile) = Malloc(char,
						strlen(argv[i])+1);
				strcpy((*prediction_outputfile), argv[i]);
				break;
			case 'p':
				model->p = atof(argv[i]);
				if (model->p < 1.0 || model->p > 2.0)
					exit_invalid_param("p", argv);
				break;
			case 'r':
				model->weight_idx = atoi(argv[i]);
				break;
			case 't':
				model->kerneltype = atoi(argv[i]);
				break;
			case 'q':
				GENSVM_OUTPUT_FILE = NULL;
				GENSVM_ERROR_FILE = NULL;
				i--;
				break;
			case 'x':
				// only consumed here; main() checks for the
				// presence of -x itself
				i--;
				break;
			case 'z':
				model->seed = atoi(argv[i]);
				break;
			default:
				// this one should always print explicitly to
				// stderr, even if '-q' is supplied, because
				// otherwise you can't debug cmdline flags.
				fprintf(stderr, "Unknown option: -%c\n",
						argv[i-1][1]);
				exit_with_help(argv);
		}
	}

	// the training data file is mandatory
	if (i >= argc)
		exit_with_help(argv);

	(*training_inputfile) = Malloc(char, strlen(argv[i])+1);
	strcpy((*training_inputfile), argv[i]);

	// a test file is only accepted as the single remaining argument
	if (i+2 == argc) {
		// size the buffer by the test file name itself (argv[i+1]),
		// not by the training file name
		(*testing_inputfile) = Malloc(char, strlen(argv[i+1])+1);
		strcpy((*testing_inputfile), argv[i+1]);
	}
}