aboutsummaryrefslogtreecommitdiff
path: root/src/gensvm_io.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/gensvm_io.c')
-rw-r--r--src/gensvm_io.c263
1 files changed, 261 insertions, 2 deletions
diff --git a/src/gensvm_io.c b/src/gensvm_io.c
index 0d14144..d41b2ef 100644
--- a/src/gensvm_io.c
+++ b/src/gensvm_io.c
@@ -6,7 +6,7 @@
*
* @details
* This file contains functions for reading and writing model files, and data
- * files. It also contains a function for generating a string of the current
+ * files. It also contains a function for generating a string of the current
* time, used in writing output files.
*
* @copyright
@@ -38,7 +38,7 @@
* Read the data from the data_file. The data matrix X is augmented
* with a column of ones, to get the matrix Z. The data is expected
* to follow a specific format, which is specified in the @ref spec_data_file.
- * The class labels are assumed to be in the interval [1 .. K], which can be
+ * The class labels are assumed to be in the interval [1 .. K], which can be
* checked using the function gensvm_check_outcome_contiguous().
*
* @param[in,out] dataset initialized GenData struct
@@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file)
}
}
/**
 * @brief Report a malformed input line and terminate the program
 *
 * @details
 * Writes a GenSVM error message containing the offending line number through
 * err() and then ends the process with status EXIT_FAILURE. This function
 * does not return to its caller. The approach is borrowed from LibSVM.
 *
 * @param[in] 	line_num 	line number where the error occurred
 *
 */
void exit_input_error(int line_num)
{
	err("[GenSVM Error]: Wrong input format on line: %i\n", line_num);

	// the input can not be parsed any further, so give up entirely
	exit(EXIT_FAILURE);
}
+
+/**
+ * @brief Read data from a file in LibSVM/SVMlight format
+ *
+ * @details
+ * This function reads data from a file where the data is stored in
+ * LibSVM/SVMlight format. This is a sparse data format, which can be
+ * beneficial for certain applications. The advantage of having this function
+ * here is twofold: 1) existing datasets where data is stored in
+ * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets
+ * which are too large for memory when kept in dense format can be loaded
+ * efficiently into GenSVM.
+ *
+ * @note
+ * This code is based on the read_problem() function in the svm-train.c
+ * file of LibSVM. It has however been expanded to be able to handle data
+ * files without labels.
+ *
+ * @note
+ * This file tries to detect whether 1-based or 0-based indexing is used in
+ * the data file. By default 1-based indexing is used, but if an index is
+ * found with value 0, 0-based indexing is assumed.
+ *
+ * @sa
+ * gensvm_read_problem()
+ *
+ * @param[in] data GenData structure
+ * @param[in] data_file filename of the datafile
+ *
+ */
+void gensvm_read_data_libsvm(struct GenData *data, char *data_file)
+{
+ bool do_sparse, zero_based = false;
+ long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels,
+ min_index = 1;
+ int n_big, n_small, big_start;
+ double value;
+ FILE *fid = NULL;
+ char *label = NULL,
+ *endptr = NULL,
+ **big_parts = NULL,
+ **small_parts = NULL;
+ char buf[GENSVM_MAX_LINE_LENGTH];
+
+ fid = fopen(data_file, "r");
+ if (fid == NULL) {
+ // LCOV_EXCL_START
+ err("[GenSVM Error]: Datafile %s could not be opened.\n",
+ data_file);
+ exit(EXIT_FAILURE);
+ // LCOV_EXCL_STOP
+ }
+
+ // first count the number of elements
+ n = 0;
+ m = -1;
+
+ num_labels = 0;
+ nnz = 0;
+
+ while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) {
+ // split the string in labels and/or index:value pairs
+ big_parts = str_split(buf, " \t", &n_big);
+
+ // record if this line has a label (first part has no colon)
+ num_labels += (!str_contains_char(big_parts[0], ':'));
+
+ // check for each part if it is a index:value pair
+ for (i=0; i<n_big; i++) {
+ if (!str_contains_char(big_parts[i], ':'))
+ continue;
+
+ // split the index:value pair
+ small_parts = str_split(big_parts[i], ":", &n_small);
+
+ // convert the index to a number
+ index = strtol(small_parts[0], &endptr, 10);
+
+ // catch conversion errors
+ if (endptr == small_parts[0] || errno != 0 ||
+ *endptr != '\0')
+ exit_input_error(n+1);
+
+ // update the maximum index
+ m = maximum(m, index);
+
+ // update the minimum index
+ min_index = minimum(min_index, index);
+
+ // free the small parts
+ for (j=0; j<n_small; j++) free(small_parts[j]);
+ free(small_parts);
+
+ // increment the nonzero counter
+ nnz++;
+ }
+
+ // free the big parts
+ for (i=0; i<n_big; i++) {
+ free(big_parts[i]);
+ }
+ free(big_parts);
+
+ // increment the number of observations
+ n++;
+ }
+
+ // rewind the file pointer
+ rewind(fid);
+
+ // check if we have enough labels
+ if (num_labels > 0 && num_labels != n) {
+ err("[GenSVM Error]: There are some lines with missing "
+ "labels. Please fix this before "
+ "continuing.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ // don't forget the column of ones
+ nnz += n;
+
+ // deal with 0-based or 1-based indexing in the LibSVM file
+ if (min_index == 0) {
+ m++;
+ zero_based = true;
+ }
+
+ // check if sparsity is worth it
+ do_sparse = gensvm_nnz_comparison(nnz, n, m+1);
+ if (do_sparse) {
+ data->spZ = gensvm_init_sparse();
+ data->spZ->nnz = nnz;
+ data->spZ->n_row = n;
+ data->spZ->n_col = m+1;
+ data->spZ->values = Calloc(double, nnz);
+ data->spZ->ia = Calloc(long, n+1);
+ data->spZ->ja = Calloc(long, nnz);
+ data->spZ->ia[0] = 0;
+ } else {
+ data->RAW = Calloc(double, n*(m+1));
+ data->Z = data->RAW;
+ }
+ if (num_labels > 0)
+ data->y = Calloc(long, n);
+
+ K = 0;
+ cnt = 0;
+ for (i=0; i<n; i++) {
+ fgets(buf, GENSVM_MAX_LINE_LENGTH, fid);
+
+ // split the string in labels and/or index:value pairs
+ big_parts = str_split(buf, " \t", &n_big);
+
+ big_start = 0;
+ // get the label from the first part if it exists
+ if (!str_contains_char(big_parts[0], ':')) {
+ label = strtok(big_parts[0], " \t\n");
+ if (label == NULL) // empty line
+ exit_input_error(i+1);
+
+ // convert the label part to a number exit if there
+ // are errors
+ tmp = strtol(label, &endptr, 10);
+ if (endptr == label || *endptr != '\0')
+ exit_input_error(i+1);
+
+ // assign label to y
+ data->y[i] = tmp;
+
+ // keep track of maximum K
+ K = maximum(K, data->y[i]);
+
+ // increment big part index
+ big_start++;
+ }
+
+ row_cnt = 0;
+ // set the first element in the row to 1
+ if (do_sparse) {
+ data->spZ->values[cnt] = 1.0;
+ data->spZ->ja[cnt] = 0;
+ cnt++;
+ row_cnt++;
+ } else {
+ matrix_set(data->RAW, m+1, i, 0, 1.0);
+ }
+
+ // read the rest of the line
+ for (j=big_start; j<n_big; j++) {
+ if (!str_contains_char(big_parts[j], ':'))
+ continue;
+
+ // split the index:value pair
+ small_parts = str_split(big_parts[j], ":", &n_small);
+ if (n_small != 2)
+ exit_input_error(n+1);
+
+ // convert the index to a long
+ errno = 0;
+ index = strtol(small_parts[0], &endptr, 10);
+
+ // catch conversion errors
+ if (endptr == small_parts[0] || errno != 0 ||
+ *endptr != '\0')
+ exit_input_error(n+1);
+
+ // convert the value to a double
+ errno = 0;
+ value = strtod(small_parts[1], &endptr);
+ if (endptr == small_parts[1] || errno != 0 ||
+ (*endptr != '\0' && !isspace(*endptr)))
+ exit_input_error(n+1);
+
+ if (do_sparse) {
+ data->spZ->values[cnt] = value;
+ data->spZ->ja[cnt] = index + zero_based;
+ cnt++;
+ row_cnt++;
+ } else {
+ matrix_set(data->RAW, m+1, i,
+ index + zero_based, value);
+ }
+
+ // free the small parts
+ free(small_parts[0]);
+ free(small_parts[1]);
+ free(small_parts);
+ }
+
+ if (do_sparse) {
+ data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt;
+ }
+
+ // free the big parts
+ for (j=0; j<n_big; j++) {
+ free(big_parts[j]);
+ }
+ free(big_parts);
+ }
+
+ fclose(fid);
+
+ data->n = n;
+ data->m = m;
+ data->r = m;
+ data->K = K;
+
+}
/**
* @brief Read model from file