diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/GenSVMgrid.c | 15 | ||||
| -rw-r--r-- | src/GenSVMtraintest.c | 25 | ||||
| -rw-r--r-- | src/gensvm_io.c | 263 | ||||
| -rw-r--r-- | src/gensvm_sparse.c | 28 | ||||
| -rw-r--r-- | src/gensvm_strutil.c | 112 |
5 files changed, 428 insertions, 15 deletions
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 3799de3..41d1070 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -72,7 +72,8 @@ void exit_with_help(char **argv) printf("Usage: %s [options] grid_file\n", argv[0]); printf("Options:\n"); printf("-h | -help : print this help.\n"); - printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-x : data files are in LibSVM/SVMlight format\n"); exit(EXIT_FAILURE); } @@ -97,6 +98,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; struct GenGrid *grid = gensvm_init_grid(); @@ -108,12 +110,16 @@ int main(int argc, char **argv) || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); parse_command_line(argc, argv, input_filename); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); read_grid_from_file(input_filename, grid); note("Reading data from %s\n", grid->train_data_file); - gensvm_read_data(train_data, grid->train_data_file); + if (libsvm_format) + gensvm_read_data_libsvm(train_data, grid->train_data_file); + else + gensvm_read_data(train_data, grid->train_data_file); // check labels of training data gensvm_check_outcome_contiguous(train_data); @@ -196,6 +202,9 @@ void parse_command_line(int argc, char **argv, char *input_filename) GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: fprintf(stderr, "Unknown option: -%c\n", argv[i-1][1]); @@ -267,7 +276,7 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid) if (fid == NULL) { fprintf(stderr, "Error opening grid file %s\n", input_filename); - exit(1); + exit(EXIT_FAILURE); } grid->traintype = CV; while ( fgets(buffer, GENSVM_MAX_LINE_LENGTH, fid) != NULL ) { diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 8bebf3a..a275e57 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -77,7 +77,6 @@ void exit_with_help(char **argv) "Huber hinge\n"); printf("-l lambda : set the value of lambda " "(lambda > 0)\n"); - printf("-s seed_model_file : use previous model as seed for V\n"); printf("-m model_output_file : write model output to file " "(not saved if no file provided)\n"); printf("-o prediction_output : write predictions of test data to " @@ -88,8 +87,11 @@ void exit_with_help(char **argv) "errors!)\n"); printf("-r rho : choose the weigth specification " "(1 = unit, 2 = group)\n"); + printf("-s seed_model_file : use previous model as seed for V\n"); printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, " "3=SIGMOID)\n"); + printf("-x : data files are in LibSVM/SVMlight " + "format\n"); printf("\n"); exit(EXIT_FAILURE); @@ -108,6 +110,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; long i, *predy = NULL; double performance; @@ -130,9 +133,15 @@ int main(int argc, char **argv) parse_command_line(argc, argv, model, &model_inputfile, &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); + + // read data from file + if (libsvm_format) + gensvm_read_data_libsvm(traindata, training_inputfile); + else + gensvm_read_data(traindata, training_inputfile); - // read data from file and check labels - gensvm_read_data(traindata, training_inputfile); + // check labels for consistency if (!gensvm_check_outcome_contiguous(traindata)) { err("[GenSVM Error]: Class labels should start from 1 and " "have no gaps. Please reformat your data.\n"); @@ -168,7 +177,11 @@ int main(int argc, char **argv) // if we also have a test set, predict labels and write to predictions // to an output file if specified if (testing_inputfile != NULL) { - gensvm_read_data(testdata, testing_inputfile); + // read the test data + if (libsvm_format) + gensvm_read_data_libsvm(testdata, testing_inputfile); + else + gensvm_read_data(testdata, testing_inputfile); // check if we are sparse and want nonlinearity if (testdata->Z == NULL && model->kerneltype != K_LINEAR) { @@ -258,6 +271,7 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = stderr; // parse options + // note: flags that don't have an argument should decrement i for (i=1; i<argc; i++) { if (argv[i][0] != '-') break; if (++i>=argc) { @@ -311,6 +325,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: // this one should always print explicitly to // stderr, even if '-q' is supplied, because diff --git a/src/gensvm_io.c b/src/gensvm_io.c index 0d14144..d41b2ef 100644 --- a/src/gensvm_io.c +++ b/src/gensvm_io.c @@ -6,7 +6,7 @@ * * @details * This file contains functions for reading and writing model files, and data - * files. It also contains a function for generating a string of the current + * files. It also contains a function for generating a string of the current * time, used in writing output files. * * @copyright @@ -38,7 +38,7 @@ * Read the data from the data_file. The data matrix X is augmented * with a column of ones, to get the matrix Z. The data is expected * to follow a specific format, which is specified in the @ref spec_data_file. - * The class labels are assumed to be in the interval [1 .. K], which can be + * The class labels are assumed to be in the interval [1 .. K], which can be * checked using the function gensvm_check_outcome_contiguous(). * * @param[in,out] dataset initialized GenData struct @@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) } } +/** + * @brief Print an error to the screen and exit (copied from LibSVM) + * + * @param[in] line_num line number where the error occured + * + */ +void exit_input_error(int line_num) +{ + err("[GenSVM Error]: Wrong input format on line: %i\n", line_num); + exit(EXIT_FAILURE); +} + +/** + * @brief Read data from a file in LibSVM/SVMlight format + * + * @details + * This function reads data from a file where the data is stored in + * LibSVM/SVMlight format. This is a sparse data format, which can be + * beneficial for certain applications. The advantage of having this function + * here is twofold: 1) existing datasets where data is stored in + * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets + * which are too large for memory when kept in dense format can be loaded + * efficiently into GenSVM. + * + * @note + * This code is based on the read_problem() function in the svm-train.c + * file of LibSVM. It has however been expanded to be able to handle data + * files without labels. + * + * @note + * This file tries to detect whether 1-based or 0-based indexing is used in + * the data file. By default 1-based indexing is used, but if an index is + * found with value 0, 0-based indexing is assumed. + * + * @sa + * gensvm_read_problem() + * + * @param[in] data GenData structure + * @param[in] data_file filename of the datafile + * + */ +void gensvm_read_data_libsvm(struct GenData *data, char *data_file) +{ + bool do_sparse, zero_based = false; + long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels, + min_index = 1; + int n_big, n_small, big_start; + double value; + FILE *fid = NULL; + char *label = NULL, + *endptr = NULL, + **big_parts = NULL, + **small_parts = NULL; + char buf[GENSVM_MAX_LINE_LENGTH]; + + fid = fopen(data_file, "r"); + if (fid == NULL) { + // LCOV_EXCL_START + err("[GenSVM Error]: Datafile %s could not be opened.\n", + data_file); + exit(EXIT_FAILURE); + // LCOV_EXCL_STOP + } + + // first count the number of elements + n = 0; + m = -1; + + num_labels = 0; + nnz = 0; + + while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) { + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + // record if this line has a label (first part has no colon) + num_labels += (!str_contains_char(big_parts[0], ':')); + + // check for each part if it is a index:value pair + for (i=0; i<n_big; i++) { + if (!str_contains_char(big_parts[i], ':')) + continue; + + // split the index:value pair + small_parts = str_split(big_parts[i], ":", &n_small); + + // convert the index to a number + index = strtol(small_parts[0], &endptr, 10); + + // catch conversion errors + if (endptr == small_parts[0] || errno != 0 || + *endptr != '\0') + exit_input_error(n+1); + + // update the maximum index + m = maximum(m, index); + + // update the minimum index + min_index = minimum(min_index, index); + + // free the small parts + for (j=0; j<n_small; j++) free(small_parts[j]); + free(small_parts); + + // increment the nonzero counter + nnz++; + } + + // free the big parts + for (i=0; i<n_big; i++) { + free(big_parts[i]); + } + free(big_parts); + + // increment the number of observations + n++; + } + + // rewind the file pointer + rewind(fid); + + // check if we have enough labels + if (num_labels > 0 && num_labels != n) { + err("[GenSVM Error]: There are some lines with missing " + "labels. Please fix this before " + "continuing.\n"); + exit(EXIT_FAILURE); + } + + // don't forget the column of ones + nnz += n; + + // deal with 0-based or 1-based indexing in the LibSVM file + if (min_index == 0) { + m++; + zero_based = true; + } + + // check if sparsity is worth it + do_sparse = gensvm_nnz_comparison(nnz, n, m+1); + if (do_sparse) { + data->spZ = gensvm_init_sparse(); + data->spZ->nnz = nnz; + data->spZ->n_row = n; + data->spZ->n_col = m+1; + data->spZ->values = Calloc(double, nnz); + data->spZ->ia = Calloc(long, n+1); + data->spZ->ja = Calloc(long, nnz); + data->spZ->ia[0] = 0; + } else { + data->RAW = Calloc(double, n*(m+1)); + data->Z = data->RAW; + } + if (num_labels > 0) + data->y = Calloc(long, n); + + K = 0; + cnt = 0; + for (i=0; i<n; i++) { + fgets(buf, GENSVM_MAX_LINE_LENGTH, fid); + + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + big_start = 0; + // get the label from the first part if it exists + if (!str_contains_char(big_parts[0], ':')) { + label = strtok(big_parts[0], " \t\n"); + if (label == NULL) // empty line + exit_input_error(i+1); + + // convert the label part to a number exit if there + // are errors + tmp = strtol(label, &endptr, 10); + if (endptr == label || *endptr != '\0') + exit_input_error(i+1); + + // assign label to y + data->y[i] = tmp; + + // keep track of maximum K + K = maximum(K, data->y[i]); + + // increment big part index + big_start++; + } + + row_cnt = 0; + // set the first element in the row to 1 + if (do_sparse) { + data->spZ->values[cnt] = 1.0; + data->spZ->ja[cnt] = 0; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, 0, 1.0); + } + + // read the rest of the line + for (j=big_start; j<n_big; j++) { + if (!str_contains_char(big_parts[j], ':')) + continue; + + // split the index:value pair + small_parts = str_split(big_parts[j], ":", &n_small); + if (n_small != 2) + exit_input_error(n+1); + + // convert the index to a long + errno = 0; + index = strtol(small_parts[0], &endptr, 10); + + // catch conversion errors + if (endptr == small_parts[0] || errno != 0 || + *endptr != '\0') + exit_input_error(n+1); + + // convert the value to a double + errno = 0; + value = strtod(small_parts[1], &endptr); + if (endptr == small_parts[1] || errno != 0 || + (*endptr != '\0' && !isspace(*endptr))) + exit_input_error(n+1); + + if (do_sparse) { + data->spZ->values[cnt] = value; + data->spZ->ja[cnt] = index + zero_based; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, + index + zero_based, value); + } + + // free the small parts + free(small_parts[0]); + free(small_parts[1]); + free(small_parts); + } + + if (do_sparse) { + data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt; + } + + // free the big parts + for (j=0; j<n_big; j++) { + free(big_parts[j]); + } + free(big_parts); + } + + fclose(fid); + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + +} /** * @brief Read model from file diff --git a/src/gensvm_sparse.c b/src/gensvm_sparse.c index ce99b3b..1a9317e 100644 --- a/src/gensvm_sparse.c +++ b/src/gensvm_sparse.c @@ -91,6 +91,23 @@ long gensvm_count_nnz(double *A, long rows, long cols) } /** + * @brief Compare the number of nonzeros is such that sparsity if worth it + * + * @details + * This is a utility function, see gensvm_could_sparse() for more info. + * + * @param[in] nnz number of nonzero elements + * @param[in] rows number of rows + * @param[in] cols number of columns + * + * @return whether or not sparsity is worth it + */ +bool gensvm_nnz_comparison(long nnz, long rows, long cols) +{ + return (nnz < (rows*(cols-1.0)-1.0)/2.0); +} + +/** * @brief Check if it is worthwile to convert to a sparse matrix * * @details @@ -100,20 +117,19 @@ long gensvm_count_nnz(double *A, long rows, long cols) * the amount of nonzero entries is small enough, the function returns the * number of nonzeros. If it is too big, it returns -1. * + * @sa + * gensvm_nnz_comparison() + * * @param[in] A matrix in dense format (RowMajor order) * @param[in] rows number of rows of A * @param[in] cols number of columns of A * - * @return + * @return whether or not sparsity is worth it */ bool gensvm_could_sparse(double *A, long rows, long cols) { long nnz = gensvm_count_nnz(A, rows, cols); - - if (nnz < (rows*(cols-1.0)-1.0)/2.0) { - return true; - } - return false; + return gensvm_nnz_comparison(nnz, rows, cols); } diff --git a/src/gensvm_strutil.c b/src/gensvm_strutil.c index f6bb8f7..7bca2ca 100644 --- a/src/gensvm_strutil.c +++ b/src/gensvm_strutil.c @@ -63,6 +63,118 @@ bool str_endswith(const char *str, const char *suf) } /** + * @brief Check if a string contains a char + * + * @details + * Simple utility function to check if a char occurs in a string. + * + * @param[in] str input string + * @param[in] c character + * + * @return number of times c occurs in str + */ +bool str_contains_char(const char *str, const char c) +{ + size_t i, len = strlen(str); + for (i=0; i<len; i++) + if (str[i] == c) + return true; + return false; +} + +/** + * @brief Count the number of times a string contains any character of another + * + * @details + * This function is used to count the number of expected parts in the function + * str_split(). It counts the number of times a character from a string of + * characters is present in an input string. + * + * @param[in] str input string + * @param[in] chars characters to count + * + * @return number of times any character from chars occurs in str + * + */ +int count_str_occurrences(const char *str, const char *chars) +{ + size_t i, j, len_str = strlen(str), + len_chars = strlen(chars); + int count = 0; + for (i=0; i<len_str; i++) { + for (j=0; j<len_chars; j++) { + count += (str[i] == chars[j]); + } + } + return count; +} + +/** + * @brief Split a string on delimiters and return an array of parts + * + * @details + * This function takes as input a string and a string of delimiters. As + * output, it gives an array of the parts of the first string, splitted on the + * characters in the second string. The input string is not changed, and the + * output contains all copies of the input string parts. + * + * @note + * The code is based on: http://stackoverflow.com/a/9210560 + * + * @param[in] original string you wish to split + * @param[in] delims string with delimiters to split on + * @param[out] len_ret length of the output array + * + * @return array of string parts + */ +char **str_split(char *original, const char *delims, int *len_ret) +{ + char *copy = NULL, + *token = NULL, + **result = NULL; + int i, count; + + size_t len = strlen(original); + size_t n_delim = strlen(delims); + + // add the null terminator to the delimiters + char all_delim[1 + n_delim]; + for (i=0; i<n_delim; i++) + all_delim[i] = delims[i]; + all_delim[n_delim] = '\0'; + + // number of occurances of the delimiters + count = count_str_occurrences(original, delims); + + // extra count in case there is a delimiter at the end + count += (str_contains_char(delims, original[len - 1])); + + // extra count for the null terminator + count++; + + // allocate the result array + result = Calloc(char *, count); + + // tokenize a copy of the original string and keep the splits + i = 0; + copy = Calloc(char, len + 1); + strcpy(copy, original); + token = strtok(copy, all_delim); + while (token) { + result[i] = Calloc(char, strlen(token) + 1); + strcpy(result[i], token); + i++; + + token = strtok(NULL, all_delim); + } + free(copy); + + *len_ret = i; + + return result; +} + +/** * @brief Move to next line in file * * @param[in] fid File opened for reading |
