From ab119782aca1a2eb9216cd721bca3ab9a0235911 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Wed, 7 Dec 2016 21:05:57 +0100 Subject: allow datasets to be stored in libsvm/svmlight format --- src/GenSVMgrid.c | 15 ++- src/GenSVMtraintest.c | 25 ++++- src/gensvm_io.c | 263 +++++++++++++++++++++++++++++++++++++++++++++++++- src/gensvm_sparse.c | 28 ++++-- src/gensvm_strutil.c | 112 +++++++++++++++++++++ 5 files changed, 428 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 3799de3..41d1070 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -72,7 +72,8 @@ void exit_with_help(char **argv) printf("Usage: %s [options] grid_file\n", argv[0]); printf("Options:\n"); printf("-h | -help : print this help.\n"); - printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-x : data files are in LibSVM/SVMlight format\n"); exit(EXIT_FAILURE); } @@ -97,6 +98,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; struct GenGrid *grid = gensvm_init_grid(); @@ -108,12 +110,16 @@ int main(int argc, char **argv) || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); parse_command_line(argc, argv, input_filename); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); read_grid_from_file(input_filename, grid); note("Reading data from %s\n", grid->train_data_file); - gensvm_read_data(train_data, grid->train_data_file); + if (libsvm_format) + gensvm_read_data_libsvm(train_data, grid->train_data_file); + else + gensvm_read_data(train_data, grid->train_data_file); // check labels of training data gensvm_check_outcome_contiguous(train_data); @@ -196,6 +202,9 @@ void parse_command_line(int argc, char **argv, char *input_filename) GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: fprintf(stderr, 
"Unknown option: -%c\n", argv[i-1][1]); @@ -267,7 +276,7 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid) if (fid == NULL) { fprintf(stderr, "Error opening grid file %s\n", input_filename); - exit(1); + exit(EXIT_FAILURE); } grid->traintype = CV; while ( fgets(buffer, GENSVM_MAX_LINE_LENGTH, fid) != NULL ) { diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 8bebf3a..a275e57 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -77,7 +77,6 @@ void exit_with_help(char **argv) "Huber hinge\n"); printf("-l lambda : set the value of lambda " "(lambda > 0)\n"); - printf("-s seed_model_file : use previous model as seed for V\n"); printf("-m model_output_file : write model output to file " "(not saved if no file provided)\n"); printf("-o prediction_output : write predictions of test data to " @@ -88,8 +87,11 @@ void exit_with_help(char **argv) "errors!)\n"); printf("-r rho : choose the weigth specification " "(1 = unit, 2 = group)\n"); + printf("-s seed_model_file : use previous model as seed for V\n"); printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, " "3=SIGMOID)\n"); + printf("-x : data files are in LibSVM/SVMlight " + "format\n"); printf("\n"); exit(EXIT_FAILURE); @@ -108,6 +110,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; long i, *predy = NULL; double performance; @@ -130,9 +133,15 @@ int main(int argc, char **argv) parse_command_line(argc, argv, model, &model_inputfile, &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); + + // read data from file + if (libsvm_format) + gensvm_read_data_libsvm(traindata, training_inputfile); + else + gensvm_read_data(traindata, training_inputfile); - // read data from file and check labels - gensvm_read_data(traindata, training_inputfile); + // check labels for consistency if (!gensvm_check_outcome_contiguous(traindata)) { 
err("[GenSVM Error]: Class labels should start from 1 and " "have no gaps. Please reformat your data.\n"); @@ -168,7 +177,11 @@ int main(int argc, char **argv) // if we also have a test set, predict labels and write to predictions // to an output file if specified if (testing_inputfile != NULL) { - gensvm_read_data(testdata, testing_inputfile); + // read the test data + if (libsvm_format) + gensvm_read_data_libsvm(testdata, testing_inputfile); + else + gensvm_read_data(testdata, testing_inputfile); // check if we are sparse and want nonlinearity if (testdata->Z == NULL && model->kerneltype != K_LINEAR) { @@ -258,6 +271,7 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = stderr; // parse options + // note: flags that don't have an argument should decrement i for (i=1; i=argc) { @@ -311,6 +325,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: // this one should always print explicitly to // stderr, even if '-q' is supplied, because diff --git a/src/gensvm_io.c b/src/gensvm_io.c index 0d14144..d41b2ef 100644 --- a/src/gensvm_io.c +++ b/src/gensvm_io.c @@ -6,7 +6,7 @@ * * @details * This file contains functions for reading and writing model files, and data - * files. It also contains a function for generating a string of the current + * files. It also contains a function for generating a string of the current * time, used in writing output files. * * @copyright @@ -38,7 +38,7 @@ * Read the data from the data_file. The data matrix X is augmented * with a column of ones, to get the matrix Z. The data is expected * to follow a specific format, which is specified in the @ref spec_data_file. - * The class labels are assumed to be in the interval [1 .. K], which can be + * The class labels are assumed to be in the interval [1 .. K], which can be * checked using the function gensvm_check_outcome_contiguous(). 
* * @param[in,out] dataset initialized GenData struct @@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) } } +/** + * @brief Print an error to the screen and exit (copied from LibSVM) + * + * @param[in] line_num line number where the error occurred + * + */ +void exit_input_error(int line_num) +{ + err("[GenSVM Error]: Wrong input format on line: %i\n", line_num); + exit(EXIT_FAILURE); +} + +/** + * @brief Read data from a file in LibSVM/SVMlight format + * + * @details + * This function reads data from a file where the data is stored in + * LibSVM/SVMlight format. This is a sparse data format, which can be + * beneficial for certain applications. The advantage of having this function + * here is twofold: 1) existing datasets where data is stored in + * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets + * which are too large for memory when kept in dense format can be loaded + * efficiently into GenSVM. + * + * @note + * This code is based on the read_problem() function in the svm-train.c + * file of LibSVM. It has however been expanded to be able to handle data + * files without labels. + * + * @note + * This function tries to detect whether 1-based or 0-based indexing is used in + * the data file. By default 1-based indexing is used, but if an index is + * found with value 0, 0-based indexing is assumed. 
+ * + * @sa + * gensvm_read_problem() + * + * @param[in] data GenData structure + * @param[in] data_file filename of the datafile + * + */ +void gensvm_read_data_libsvm(struct GenData *data, char *data_file) +{ + bool do_sparse, zero_based = false; + long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels, + min_index = 1; + int n_big, n_small, big_start; + double value; + FILE *fid = NULL; + char *label = NULL, + *endptr = NULL, + **big_parts = NULL, + **small_parts = NULL; + char buf[GENSVM_MAX_LINE_LENGTH]; + + fid = fopen(data_file, "r"); + if (fid == NULL) { + // LCOV_EXCL_START + err("[GenSVM Error]: Datafile %s could not be opened.\n", + data_file); + exit(EXIT_FAILURE); + // LCOV_EXCL_STOP + } + + // first count the number of elements + n = 0; + m = -1; + + num_labels = 0; + nnz = 0; + + while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) { + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + // record if this line has a label (first part has no colon) + num_labels += (!str_contains_char(big_parts[0], ':')); + + // check for each part if it is a index:value pair + for (i=0; i 0 && num_labels != n) { + err("[GenSVM Error]: There are some lines with missing " + "labels. 
Please fix this before " + "continuing.\n"); + exit(EXIT_FAILURE); + } + + // don't forget the column of ones + nnz += n; + + // deal with 0-based or 1-based indexing in the LibSVM file + if (min_index == 0) { + m++; + zero_based = true; + } + + // check if sparsity is worth it + do_sparse = gensvm_nnz_comparison(nnz, n, m+1); + if (do_sparse) { + data->spZ = gensvm_init_sparse(); + data->spZ->nnz = nnz; + data->spZ->n_row = n; + data->spZ->n_col = m+1; + data->spZ->values = Calloc(double, nnz); + data->spZ->ia = Calloc(long, n+1); + data->spZ->ja = Calloc(long, nnz); + data->spZ->ia[0] = 0; + } else { + data->RAW = Calloc(double, n*(m+1)); + data->Z = data->RAW; + } + if (num_labels > 0) + data->y = Calloc(long, n); + + K = 0; + cnt = 0; + for (i=0; iy[i] = tmp; + + // keep track of maximum K + K = maximum(K, data->y[i]); + + // increment big part index + big_start++; + } + + row_cnt = 0; + // set the first element in the row to 1 + if (do_sparse) { + data->spZ->values[cnt] = 1.0; + data->spZ->ja[cnt] = 0; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, 0, 1.0); + } + + // read the rest of the line + for (j=big_start; jspZ->values[cnt] = value; + data->spZ->ja[cnt] = index + zero_based; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, + index + zero_based, value); + } + + // free the small parts + free(small_parts[0]); + free(small_parts[1]); + free(small_parts); + } + + if (do_sparse) { + data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt; + } + + // free the big parts + for (j=0; jn = n; + data->m = m; + data->r = m; + data->K = K; + +} /** * @brief Read model from file diff --git a/src/gensvm_sparse.c b/src/gensvm_sparse.c index ce99b3b..1a9317e 100644 --- a/src/gensvm_sparse.c +++ b/src/gensvm_sparse.c @@ -90,6 +90,23 @@ long gensvm_count_nnz(double *A, long rows, long cols) return nnz; } +/** + * @brief Compare the number of nonzeros is such that sparsity if worth it + * + * @details + * This is a utility function, see 
gensvm_could_sparse() for more info. + * + * @param[in] nnz number of nonzero elements + * @param[in] rows number of rows + * @param[in] cols number of columns + * + * @return whether or not sparsity is worth it + */ +bool gensvm_nnz_comparison(long nnz, long rows, long cols) +{ + return (nnz < (rows*(cols-1.0)-1.0)/2.0); +} + /** * @brief Check if it is worthwile to convert to a sparse matrix * @@ -100,20 +117,19 @@ long gensvm_count_nnz(double *A, long rows, long cols) * the amount of nonzero entries is small enough, the function returns the * number of nonzeros. If it is too big, it returns -1. * + * @sa + * gensvm_nnz_comparison() + * * @param[in] A matrix in dense format (RowMajor order) * @param[in] rows number of rows of A * @param[in] cols number of columns of A * - * @return + * @return whether or not sparsity is worth it */ bool gensvm_could_sparse(double *A, long rows, long cols) { long nnz = gensvm_count_nnz(A, rows, cols); - - if (nnz < (rows*(cols-1.0)-1.0)/2.0) { - return true; - } - return false; + return gensvm_nnz_comparison(nnz, rows, cols); } diff --git a/src/gensvm_strutil.c b/src/gensvm_strutil.c index f6bb8f7..7bca2ca 100644 --- a/src/gensvm_strutil.c +++ b/src/gensvm_strutil.c @@ -62,6 +62,118 @@ bool str_endswith(const char *str, const char *suf) lensuf) == 0; } +/** + * @brief Check if a string contains a char + * + * @details + * Simple utility function to check if a char occurs in a string. + * + * @param[in] str input string + * @param[in] c character + * + * @return true if c occurs in str, false otherwise + */ +bool str_contains_char(const char *str, const char c) +{ + size_t i, len = strlen(str); + for (i=0; i