diff options
Diffstat (limited to 'src/gensvm_io.c')
| -rw-r--r-- | src/gensvm_io.c | 263 |
1 files changed, 261 insertions, 2 deletions
diff --git a/src/gensvm_io.c b/src/gensvm_io.c index 0d14144..d41b2ef 100644 --- a/src/gensvm_io.c +++ b/src/gensvm_io.c @@ -6,7 +6,7 @@ * * @details * This file contains functions for reading and writing model files, and data - * files. It also contains a function for generating a string of the current + * files. It also contains a function for generating a string of the current * time, used in writing output files. * * @copyright @@ -38,7 +38,7 @@ * Read the data from the data_file. The data matrix X is augmented * with a column of ones, to get the matrix Z. The data is expected * to follow a specific format, which is specified in the @ref spec_data_file. - * The class labels are assumed to be in the interval [1 .. K], which can be + * The class labels are assumed to be in the interval [1 .. K], which can be * checked using the function gensvm_check_outcome_contiguous(). * * @param[in,out] dataset initialized GenData struct @@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) } } +/** + * @brief Print an error to the screen and exit (copied from LibSVM) + * + * @param[in] line_num line number where the error occured + * + */ +void exit_input_error(int line_num) +{ + err("[GenSVM Error]: Wrong input format on line: %i\n", line_num); + exit(EXIT_FAILURE); +} + +/** + * @brief Read data from a file in LibSVM/SVMlight format + * + * @details + * This function reads data from a file where the data is stored in + * LibSVM/SVMlight format. This is a sparse data format, which can be + * beneficial for certain applications. The advantage of having this function + * here is twofold: 1) existing datasets where data is stored in + * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets + * which are too large for memory when kept in dense format can be loaded + * efficiently into GenSVM. + * + * @note + * This code is based on the read_problem() function in the svm-train.c + * file of LibSVM. It has however been expanded to be able to handle data + * files without labels. + * + * @note + * This file tries to detect whether 1-based or 0-based indexing is used in + * the data file. By default 1-based indexing is used, but if an index is + * found with value 0, 0-based indexing is assumed. + * + * @sa + * gensvm_read_problem() + * + * @param[in] data GenData structure + * @param[in] data_file filename of the datafile + * + */ +void gensvm_read_data_libsvm(struct GenData *data, char *data_file) +{ + bool do_sparse, zero_based = false; + long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels, + min_index = 1; + int n_big, n_small, big_start; + double value; + FILE *fid = NULL; + char *label = NULL, + *endptr = NULL, + **big_parts = NULL, + **small_parts = NULL; + char buf[GENSVM_MAX_LINE_LENGTH]; + + fid = fopen(data_file, "r"); + if (fid == NULL) { + // LCOV_EXCL_START + err("[GenSVM Error]: Datafile %s could not be opened.\n", + data_file); + exit(EXIT_FAILURE); + // LCOV_EXCL_STOP + } + + // first count the number of elements + n = 0; + m = -1; + + num_labels = 0; + nnz = 0; + + while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) { + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + // record if this line has a label (first part has no colon) + num_labels += (!str_contains_char(big_parts[0], ':')); + + // check for each part if it is a index:value pair + for (i=0; i<n_big; i++) { + if (!str_contains_char(big_parts[i], ':')) + continue; + + // split the index:value pair + small_parts = str_split(big_parts[i], ":", &n_small); + + // convert the index to a number + index = strtol(small_parts[0], &endptr, 10); + + // catch conversion errors + if (endptr == small_parts[0] || errno != 0 || + *endptr != '\0') + exit_input_error(n+1); + + // update the maximum index + m = maximum(m, index); + + // update the minimum index + min_index = minimum(min_index, index); + + // free the small parts + for (j=0; j<n_small; j++) free(small_parts[j]); + free(small_parts); + + // increment the nonzero counter + nnz++; + } + + // free the big parts + for (i=0; i<n_big; i++) { + free(big_parts[i]); + } + free(big_parts); + + // increment the number of observations + n++; + } + + // rewind the file pointer + rewind(fid); + + // check if we have enough labels + if (num_labels > 0 && num_labels != n) { + err("[GenSVM Error]: There are some lines with missing " + "labels. Please fix this before " + "continuing.\n"); + exit(EXIT_FAILURE); + } + + // don't forget the column of ones + nnz += n; + + // deal with 0-based or 1-based indexing in the LibSVM file + if (min_index == 0) { + m++; + zero_based = true; + } + + // check if sparsity is worth it + do_sparse = gensvm_nnz_comparison(nnz, n, m+1); + if (do_sparse) { + data->spZ = gensvm_init_sparse(); + data->spZ->nnz = nnz; + data->spZ->n_row = n; + data->spZ->n_col = m+1; + data->spZ->values = Calloc(double, nnz); + data->spZ->ia = Calloc(long, n+1); + data->spZ->ja = Calloc(long, nnz); + data->spZ->ia[0] = 0; + } else { + data->RAW = Calloc(double, n*(m+1)); + data->Z = data->RAW; + } + if (num_labels > 0) + data->y = Calloc(long, n); + + K = 0; + cnt = 0; + for (i=0; i<n; i++) { + fgets(buf, GENSVM_MAX_LINE_LENGTH, fid); + + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + big_start = 0; + // get the label from the first part if it exists + if (!str_contains_char(big_parts[0], ':')) { + label = strtok(big_parts[0], " \t\n"); + if (label == NULL) // empty line + exit_input_error(i+1); + + // convert the label part to a number exit if there + // are errors + tmp = strtol(label, &endptr, 10); + if (endptr == label || *endptr != '\0') + exit_input_error(i+1); + + // assign label to y + data->y[i] = tmp; + + // keep track of maximum K + K = maximum(K, data->y[i]); + + // increment big part index + big_start++; + } + + row_cnt = 0; + // set the first element in the row to 1 + if (do_sparse) { + data->spZ->values[cnt] = 1.0; + data->spZ->ja[cnt] = 0; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, 0, 1.0); + } + + // read the rest of the line + for (j=big_start; j<n_big; j++) { + if (!str_contains_char(big_parts[j], ':')) + continue; + + // split the index:value pair + small_parts = str_split(big_parts[j], ":", &n_small); + if (n_small != 2) + exit_input_error(n+1); + + // convert the index to a long + errno = 0; + index = strtol(small_parts[0], &endptr, 10); + + // catch conversion errors + if (endptr == small_parts[0] || errno != 0 || + *endptr != '\0') + exit_input_error(n+1); + + // convert the value to a double + errno = 0; + value = strtod(small_parts[1], &endptr); + if (endptr == small_parts[1] || errno != 0 || + (*endptr != '\0' && !isspace(*endptr))) + exit_input_error(n+1); + + if (do_sparse) { + data->spZ->values[cnt] = value; + data->spZ->ja[cnt] = index + zero_based; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, + index + zero_based, value); + } + + // free the small parts + free(small_parts[0]); + free(small_parts[1]); + free(small_parts); + } + + if (do_sparse) { + data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt; + } + + // free the big parts + for (j=0; j<n_big; j++) { + free(big_parts[j]); + } + free(big_parts); + } + + fclose(fid); + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + +} /** * @brief Read model from file |
