aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2016-12-07 21:05:57 +0100
committerGertjan van den Burg <burg@ese.eur.nl>2016-12-07 21:05:57 +0100
commitab119782aca1a2eb9216cd721bca3ab9a0235911 (patch)
tree6701667265066758f8150dc5ba29ceee0c28ca83
parentmoved check for class labels to seperate module (diff)
downloadgensvm-ab119782aca1a2eb9216cd721bca3ab9a0235911.tar.gz
gensvm-ab119782aca1a2eb9216cd721bca3ab9a0235911.zip
allow datasets to be stored in libsvm/svmlight format
-rw-r--r--include/gensvm_globals.h10
-rw-r--r--include/gensvm_io.h1
-rw-r--r--include/gensvm_sparse.h1
-rw-r--r--include/gensvm_strutil.h2
-rw-r--r--src/GenSVMgrid.c15
-rw-r--r--src/GenSVMtraintest.c25
-rw-r--r--src/gensvm_io.c263
-rw-r--r--src/gensvm_sparse.c28
-rw-r--r--src/gensvm_strutil.c112
-rw-r--r--tests/data/test_file_read_data_libsvm.txt5
-rw-r--r--tests/data/test_file_read_data_libsvm_0.txt5
-rw-r--r--tests/data/test_file_read_data_no_label_libsvm.txt5
-rw-r--r--tests/data/test_file_read_data_sparse_libsvm.txt10
-rw-r--r--tests/src/test_gensvm_io.c320
14 files changed, 783 insertions, 19 deletions
diff --git a/include/gensvm_globals.h b/include/gensvm_globals.h
index 1151069..1aca458 100644
--- a/include/gensvm_globals.h
+++ b/include/gensvm_globals.h
@@ -39,15 +39,17 @@
#include "gensvm_memory.h"
// all system libraries are included here
+#include <cblas.h>
+#include <ctype.h>
+#include <errno.h>
+#include <limits.h>
+#include <math.h>
#include <stdarg.h>
+#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
-#include <stdbool.h>
#include <string.h>
-#include <math.h>
#include <time.h>
-#include <cblas.h>
-#include <limits.h>
// ########################### Type definitions ########################### //
diff --git a/include/gensvm_io.h b/include/gensvm_io.h
index afcf4a3..c18e9a2 100644
--- a/include/gensvm_io.h
+++ b/include/gensvm_io.h
@@ -37,6 +37,7 @@
// function declarations
void gensvm_read_data(struct GenData *dataset, char *data_file);
+void gensvm_read_data_libsvm(struct GenData *dataset, char *data_file);
void gensvm_read_model(struct GenModel *model, char *model_filename);
void gensvm_write_model(struct GenModel *model, char *output_filename);
diff --git a/include/gensvm_sparse.h b/include/gensvm_sparse.h
index 570e19a..f897889 100644
--- a/include/gensvm_sparse.h
+++ b/include/gensvm_sparse.h
@@ -71,6 +71,7 @@ struct GenSparse {
struct GenSparse *gensvm_init_sparse();
void gensvm_free_sparse(struct GenSparse *sp);
long gensvm_count_nnz(double *A, long rows, long cols);
+bool gensvm_nnz_comparison(long nnz, long rows, long cols);
bool gensvm_could_sparse(double *A, long rows, long cols);
struct GenSparse *gensvm_dense_to_sparse(double *A, long rows, long cols);
double *gensvm_sparse_to_dense(struct GenSparse *A);
diff --git a/include/gensvm_strutil.h b/include/gensvm_strutil.h
index 4e86944..1a2dccc 100644
--- a/include/gensvm_strutil.h
+++ b/include/gensvm_strutil.h
@@ -35,6 +35,8 @@
bool str_startswith(const char *str, const char *pre);
bool str_endswith(const char *str, const char *suf);
+bool str_contains_char(const char *str, const char c);
+char **str_split(char *original, const char *delims, int *len_ret);
void next_line(FILE *fid, char *filename);
char *get_line(FILE *fid, char *filename, char *buffer);
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c
index 3799de3..41d1070 100644
--- a/src/GenSVMgrid.c
+++ b/src/GenSVMgrid.c
@@ -72,7 +72,8 @@ void exit_with_help(char **argv)
printf("Usage: %s [options] grid_file\n", argv[0]);
printf("Options:\n");
printf("-h | -help : print this help.\n");
- printf("-q : quiet mode (no output, not even errors!)\n");
+ printf("-q : quiet mode (no output, not even errors!)\n");
+ printf("-x : data files are in LibSVM/SVMlight format\n");
exit(EXIT_FAILURE);
}
@@ -97,6 +98,7 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ bool libsvm_format = false;
char input_filename[GENSVM_MAX_LINE_LENGTH];
struct GenGrid *grid = gensvm_init_grid();
@@ -108,12 +110,16 @@ int main(int argc, char **argv)
|| gensvm_check_argv_eq(argc, argv, "-h") )
exit_with_help(argv);
parse_command_line(argc, argv, input_filename);
+ libsvm_format = gensvm_check_argv(argc, argv, "-x");
note("Reading grid file\n");
read_grid_from_file(input_filename, grid);
note("Reading data from %s\n", grid->train_data_file);
- gensvm_read_data(train_data, grid->train_data_file);
+ if (libsvm_format)
+ gensvm_read_data_libsvm(train_data, grid->train_data_file);
+ else
+ gensvm_read_data(train_data, grid->train_data_file);
// check labels of training data
gensvm_check_outcome_contiguous(train_data);
@@ -196,6 +202,9 @@ void parse_command_line(int argc, char **argv, char *input_filename)
GENSVM_ERROR_FILE = NULL;
i--;
break;
+ case 'x':
+ i--;
+ break;
default:
fprintf(stderr, "Unknown option: -%c\n",
argv[i-1][1]);
@@ -267,7 +276,7 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid)
if (fid == NULL) {
fprintf(stderr, "Error opening grid file %s\n",
input_filename);
- exit(1);
+ exit(EXIT_FAILURE);
}
grid->traintype = CV;
while ( fgets(buffer, GENSVM_MAX_LINE_LENGTH, fid) != NULL ) {
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c
index 8bebf3a..a275e57 100644
--- a/src/GenSVMtraintest.c
+++ b/src/GenSVMtraintest.c
@@ -77,7 +77,6 @@ void exit_with_help(char **argv)
"Huber hinge\n");
printf("-l lambda : set the value of lambda "
"(lambda > 0)\n");
- printf("-s seed_model_file : use previous model as seed for V\n");
printf("-m model_output_file : write model output to file "
"(not saved if no file provided)\n");
printf("-o prediction_output : write predictions of test data to "
@@ -88,8 +87,11 @@ void exit_with_help(char **argv)
"errors!)\n");
printf("-r rho : choose the weigth specification "
"(1 = unit, 2 = group)\n");
+ printf("-s seed_model_file : use previous model as seed for V\n");
printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, "
"3=SIGMOID)\n");
+ printf("-x : data files are in LibSVM/SVMlight "
+ "format\n");
printf("\n");
exit(EXIT_FAILURE);
@@ -108,6 +110,7 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ bool libsvm_format = false;
long i, *predy = NULL;
double performance;
@@ -130,9 +133,15 @@ int main(int argc, char **argv)
parse_command_line(argc, argv, model, &model_inputfile,
&training_inputfile, &testing_inputfile,
&model_outputfile, &prediction_outputfile);
+ libsvm_format = gensvm_check_argv(argc, argv, "-x");
+
+ // read data from file
+ if (libsvm_format)
+ gensvm_read_data_libsvm(traindata, training_inputfile);
+ else
+ gensvm_read_data(traindata, training_inputfile);
- // read data from file and check labels
- gensvm_read_data(traindata, training_inputfile);
+ // check labels for consistency
if (!gensvm_check_outcome_contiguous(traindata)) {
err("[GenSVM Error]: Class labels should start from 1 and "
"have no gaps. Please reformat your data.\n");
@@ -168,7 +177,11 @@ int main(int argc, char **argv)
// if we also have a test set, predict labels and write to predictions
// to an output file if specified
if (testing_inputfile != NULL) {
- gensvm_read_data(testdata, testing_inputfile);
+ // read the test data
+ if (libsvm_format)
+ gensvm_read_data_libsvm(testdata, testing_inputfile);
+ else
+ gensvm_read_data(testdata, testing_inputfile);
// check if we are sparse and want nonlinearity
if (testdata->Z == NULL && model->kerneltype != K_LINEAR) {
@@ -258,6 +271,7 @@ void parse_command_line(int argc, char **argv, struct GenModel *model,
GENSVM_ERROR_FILE = stderr;
// parse options
+ // note: flags that don't have an argument should decrement i
for (i=1; i<argc; i++) {
if (argv[i][0] != '-') break;
if (++i>=argc) {
@@ -311,6 +325,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model,
GENSVM_ERROR_FILE = NULL;
i--;
break;
+ case 'x':
+ i--;
+ break;
default:
// this one should always print explicitly to
// stderr, even if '-q' is supplied, because
diff --git a/src/gensvm_io.c b/src/gensvm_io.c
index 0d14144..d41b2ef 100644
--- a/src/gensvm_io.c
+++ b/src/gensvm_io.c
@@ -6,7 +6,7 @@
*
* @details
* This file contains functions for reading and writing model files, and data
- * files. It also contains a function for generating a string of the current
+ * files. It also contains a function for generating a string of the current
* time, used in writing output files.
*
* @copyright
@@ -38,7 +38,7 @@
* Read the data from the data_file. The data matrix X is augmented
* with a column of ones, to get the matrix Z. The data is expected
* to follow a specific format, which is specified in the @ref spec_data_file.
- * The class labels are assumed to be in the interval [1 .. K], which can be
+ * The class labels are assumed to be in the interval [1 .. K], which can be
* checked using the function gensvm_check_outcome_contiguous().
*
* @param[in,out] dataset initialized GenData struct
@@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file)
}
}
+/**
+ * @brief Report a malformed input line via err() and exit (copied from LibSVM)
+ *
+ * @param[in] line_num line number where the error occurred
+ * @note This function does not return; it calls exit(EXIT_FAILURE).
+ */
+void exit_input_error(int line_num)
+{
+ err("[GenSVM Error]: Wrong input format on line: %i\n", line_num);
+ exit(EXIT_FAILURE);
+}
+
+/**
+ * @brief Read data from a file in LibSVM/SVMlight format
+ *
+ * @details
+ * This function reads data from a file where the data is stored in
+ * LibSVM/SVMlight format. This is a sparse data format, which can be
+ * beneficial for certain applications. The advantage of having this function
+ * here is twofold: 1) existing datasets where data is stored in
+ * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets
+ * which are too large for memory when kept in dense format can be loaded
+ * efficiently into GenSVM. The file is read in two passes: the first counts
+ * rows, nonzeros and the index range, the second fills Z (or spZ) and y.
+ *
+ * @note
+ * This code is based on the read_problem() function in the svm-train.c
+ * file of LibSVM. It has however been expanded to be able to handle data
+ * files without labels. Malformed lines terminate the program.
+ *
+ * @note
+ * This function tries to detect whether 1-based or 0-based indexing is used
+ * in the data file. By default 1-based indexing is used, but if an index is
+ * found with value 0, 0-based indexing is assumed.
+ *
+ * @sa
+ * gensvm_read_problem()
+ *
+ * @param[in,out] data initialized GenData structure that receives the data
+ * @param[in] data_file filename of the datafile
+ */
+void gensvm_read_data_libsvm(struct GenData *data, char *data_file)
+{
+	bool do_sparse, zero_based = false;
+	long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels,
+	     min_index = 1;
+	int n_big, n_small, big_start;
+	double value;
+	FILE *fid = NULL;
+	char *label = NULL,
+	     *endptr = NULL,
+	     **big_parts = NULL,
+	     **small_parts = NULL;
+	char buf[GENSVM_MAX_LINE_LENGTH];
+
+	fid = fopen(data_file, "r");
+	if (fid == NULL) {
+		// LCOV_EXCL_START
+		err("[GenSVM Error]: Datafile %s could not be opened.\n",
+				data_file);
+		exit(EXIT_FAILURE);
+		// LCOV_EXCL_STOP
+	}
+
+	// first count the number of elements
+	n = 0;
+	m = -1;
+
+	num_labels = 0;
+	nnz = 0;
+
+	while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) {
+		// split the string in labels and/or index:value pairs
+		big_parts = str_split(buf, " \t", &n_big);
+		if (n_big == 0) // only delimiters, big_parts[0] is not set
+			exit_input_error(n+1);
+
+		// record if this line has a label (first part has no colon)
+		num_labels += (!str_contains_char(big_parts[0], ':'));
+
+		// check for each part if it is a index:value pair
+		for (i=0; i<n_big; i++) {
+			if (!str_contains_char(big_parts[i], ':'))
+				continue;
+
+			// split the index:value pair, both parts must exist
+			small_parts = str_split(big_parts[i], ":", &n_small);
+			if (n_small != 2)
+				exit_input_error(n+1);
+
+			// convert the index to a number, errno is reset so a
+			// stale value is not mistaken for a conversion error
+			errno = 0;
+			index = strtol(small_parts[0], &endptr, 10);
+			if (endptr == small_parts[0] || errno != 0 ||
+					*endptr != '\0' || index < 0)
+				exit_input_error(n+1);
+
+			// update the index range and the nonzero counter
+			m = maximum(m, index);
+			min_index = minimum(min_index, index);
+			nnz++;
+
+			// free the small parts
+			for (j=0; j<n_small; j++) free(small_parts[j]);
+			free(small_parts);
+		}
+
+		// free the big parts
+		for (i=0; i<n_big; i++) {
+			free(big_parts[i]);
+		}
+		free(big_parts);
+
+		// increment the number of observations
+		n++;
+	}
+
+	// rewind the file pointer
+	rewind(fid);
+
+	// check if we have enough labels
+	if (num_labels > 0 && num_labels != n) {
+		err("[GenSVM Error]: There are some lines with missing "
+				"labels. Please fix this before "
+				"continuing.\n");
+		exit(EXIT_FAILURE);
+	}
+
+	// don't forget the column of ones
+	nnz += n;
+
+	// deal with 0-based or 1-based indexing in the LibSVM file
+	if (min_index == 0) {
+		m++;
+		zero_based = true;
+	}
+
+	// check if sparsity is worth it
+	do_sparse = gensvm_nnz_comparison(nnz, n, m+1);
+	if (do_sparse) {
+		data->spZ = gensvm_init_sparse();
+		data->spZ->nnz = nnz;
+		data->spZ->n_row = n;
+		data->spZ->n_col = m+1;
+		data->spZ->values = Calloc(double, nnz);
+		data->spZ->ia = Calloc(long, n+1);
+		data->spZ->ja = Calloc(long, nnz);
+		data->spZ->ia[0] = 0;
+	} else {
+		data->RAW = Calloc(double, n*(m+1));
+		data->Z = data->RAW;
+	}
+	if (num_labels > 0)
+		data->y = Calloc(long, n);
+
+	K = 0;
+	cnt = 0;
+	for (i=0; i<n; i++) {
+		if (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) == NULL)
+			exit_input_error(i+1);
+		// split the string in labels and/or index:value pairs
+		big_parts = str_split(buf, " \t", &n_big);
+
+		big_start = 0;
+		// get the label from the first part if it exists
+		if (!str_contains_char(big_parts[0], ':')) {
+			label = strtok(big_parts[0], " \t\n");
+			if (label == NULL) // empty line
+				exit_input_error(i+1);
+
+			// convert the label part to a number exit if there
+			// are errors
+			tmp = strtol(label, &endptr, 10);
+			if (endptr == label || *endptr != '\0')
+				exit_input_error(i+1);
+
+			// assign label to y
+			data->y[i] = tmp;
+
+			// keep track of maximum K
+			K = maximum(K, data->y[i]);
+
+			// increment big part index
+			big_start++;
+		}
+
+		row_cnt = 0;
+		// set the first element in the row to 1
+		if (do_sparse) {
+			data->spZ->values[cnt] = 1.0;
+			data->spZ->ja[cnt] = 0;
+			cnt++;
+			row_cnt++;
+		} else {
+			matrix_set(data->RAW, m+1, i, 0, 1.0);
+		}
+
+		// read the rest of the line
+		for (j=big_start; j<n_big; j++) {
+			if (!str_contains_char(big_parts[j], ':'))
+				continue;
+
+			// split the index:value pair
+			small_parts = str_split(big_parts[j], ":", &n_small);
+			if (n_small != 2)
+				exit_input_error(i+1);
+
+			// convert the index to a long
+			errno = 0;
+			index = strtol(small_parts[0], &endptr, 10);
+
+			// catch conversion errors
+			if (endptr == small_parts[0] || errno != 0 ||
+					*endptr != '\0')
+				exit_input_error(i+1);
+
+			// convert the value to a double
+			errno = 0;
+			value = strtod(small_parts[1], &endptr);
+			if (endptr == small_parts[1] || errno != 0 ||
+					(*endptr != '\0' && !isspace((unsigned char) *endptr)))
+				exit_input_error(i+1);
+
+			if (do_sparse) {
+				data->spZ->values[cnt] = value;
+				data->spZ->ja[cnt] = index + zero_based;
+				cnt++;
+				row_cnt++;
+			} else {
+				matrix_set(data->RAW, m+1, i,
+						index + zero_based, value);
+			}
+
+			// free the small parts
+			free(small_parts[0]);
+			free(small_parts[1]);
+			free(small_parts);
+		}
+
+		if (do_sparse) {
+			data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt;
+		}
+
+		// free the big parts
+		for (j=0; j<n_big; j++) {
+			free(big_parts[j]);
+		}
+		free(big_parts);
+	}
+
+	fclose(fid);
+
+	data->n = n;
+	data->m = m;
+	data->r = m;
+	data->K = K;
+
+}
/**
* @brief Read model from file
diff --git a/src/gensvm_sparse.c b/src/gensvm_sparse.c
index ce99b3b..1a9317e 100644
--- a/src/gensvm_sparse.c
+++ b/src/gensvm_sparse.c
@@ -91,6 +91,23 @@ long gensvm_count_nnz(double *A, long rows, long cols)
}
/**
+ * @brief Check if the number of nonzeros is small enough to justify sparsity
+ *
+ * @details
+ * This is a utility function, see gensvm_could_sparse() for more info.
+ *
+ * @param[in] nnz number of nonzero elements
+ * @param[in] rows number of rows
+ * @param[in] cols number of columns
+ *
+ * @return true if nnz < (rows*(cols-1) - 1)/2, i.e. sparsity is worth it
+ */
+bool gensvm_nnz_comparison(long nnz, long rows, long cols)
+{
+ return (nnz < (rows*(cols-1.0)-1.0)/2.0);
+}
+
+/**
* @brief Check if it is worthwile to convert to a sparse matrix
*
* @details
@@ -100,20 +117,19 @@ long gensvm_count_nnz(double *A, long rows, long cols)
* the amount of nonzero entries is small enough, the function returns the
* number of nonzeros. If it is too big, it returns -1.
*
+ * @sa
+ * gensvm_nnz_comparison()
+ *
* @param[in] A matrix in dense format (RowMajor order)
* @param[in] rows number of rows of A
* @param[in] cols number of columns of A
*
- * @return
+ * @return whether or not sparsity is worth it
*/
bool gensvm_could_sparse(double *A, long rows, long cols)
{
long nnz = gensvm_count_nnz(A, rows, cols);
-
- if (nnz < (rows*(cols-1.0)-1.0)/2.0) {
- return true;
- }
- return false;
+ return gensvm_nnz_comparison(nnz, rows, cols);
}
diff --git a/src/gensvm_strutil.c b/src/gensvm_strutil.c
index f6bb8f7..7bca2ca 100644
--- a/src/gensvm_strutil.c
+++ b/src/gensvm_strutil.c
@@ -63,6 +63,118 @@ bool str_endswith(const char *str, const char *suf)
}
/**
+ * @brief Check if a string contains a char
+ *
+ * @details
+ * Simple utility function to check if a char occurs in a string.
+ *
+ * @param[in] str input string (must not be NULL, it is passed to strlen())
+ * @param[in] c character
+ *
+ * @return true if c occurs in str, false otherwise
+ */
+bool str_contains_char(const char *str, const char c)
+{
+ size_t i, len = strlen(str);
+ for (i=0; i<len; i++)
+ if (str[i] == c)
+ return true;
+ return false;
+}
+
+/**
+ * @brief Count the number of times a string contains any character of another
+ *
+ * @details
+ * This function is used to count the number of expected parts in the function
+ * str_split(). It counts the number of times a character from a string of
+ * characters is present in an input string.
+ *
+ * @param[in] str input string
+ * @param[in] chars characters to count
+ *
+ * @return number of times any character from chars occurs in str
+ * @note A character listed more than once in chars is counted once per entry.
+ */
+int count_str_occurrences(const char *str, const char *chars)
+{
+ size_t i, j, len_str = strlen(str),
+ len_chars = strlen(chars);
+ int count = 0;
+ for (i=0; i<len_str; i++) {
+ for (j=0; j<len_chars; j++) {
+ count += (str[i] == chars[j]);
+ }
+ }
+ return count;
+}
+
+/**
+ * @brief Split a string on delimiters and return an array of parts
+ *
+ * @details
+ * This function takes as input a string and a string of delimiters. As
+ * output, it gives an array of the parts of the first string, split on the
+ * characters in the second string. The input string is not changed, and the
+ * output contains copies of the parts; the caller frees each part and the
+ * array itself. Empty parts between adjacent delimiters are not returned.
+ * @note
+ * The code is based on: http://stackoverflow.com/a/9210560
+ *
+ * @param[in] original string you wish to split
+ * @param[in] delims string with delimiters to split on
+ * @param[out] len_ret length of the output array
+ *
+ * @return array of string parts
+ */
+char **str_split(char *original, const char *delims, int *len_ret)
+{
+	char *copy = NULL,
+	     *token = NULL,
+	     **result = NULL;
+	int i, count;
+
+	size_t len = strlen(original);
+
+	// number of occurrences of the delimiters
+	count = count_str_occurrences(original, delims);
+
+	// extra count in case there is a delimiter at the end; the length
+	// check avoids reading before the start of an empty string
+	if (len > 0)
+		count += (str_contains_char(delims, original[len - 1]));
+
+	// there can be one more part than there are delimiters, so this
+	// makes count an upper bound on the number of parts
+	count++;
+
+	// allocate the result array
+	result = Calloc(char *, count);
+
+	// tokenize a copy of the original string and keep the splits. The
+	// copy is needed because strtok writes into its argument; delims
+	// is already null-terminated and can be passed to strtok directly
+	i = 0;
+	copy = Calloc(char, len + 1);
+	strcpy(copy, original);
+	token = strtok(copy, delims);
+	while (token) {
+		// store a private copy of this part
+		result[i] = Calloc(char, strlen(token) + 1);
+		strcpy(result[i], token);
+		i++;
+
+		token = strtok(NULL, delims);
+	}
+	free(copy);
+
+	// report the number of parts actually found
+	*len_ret = i;
+
+	return result;
+}
+
+/**
* @brief Move to next line in file
*
* @param[in] fid File opened for reading
diff --git a/tests/data/test_file_read_data_libsvm.txt b/tests/data/test_file_read_data_libsvm.txt
new file mode 100644
index 0000000..da1c0ff
--- /dev/null
+++ b/tests/data/test_file_read_data_libsvm.txt
@@ -0,0 +1,5 @@
+2 1:0.7065937536993949 2:0.7016517970438980 3:0.1548611397288129
+1 1:0.4604987687863951 2:0.6374142980176117 3:0.0370930278245423
+3 1:0.3798777132278375 2:0.5745070018747664 3:0.2570906697837264
+4 1:0.2789376050039792 2:0.4853242744610165 3:0.1894010436762711
+3 1:0.7630904372339489 2:0.1341546320318005 3:0.6827430912944857
diff --git a/tests/data/test_file_read_data_libsvm_0.txt b/tests/data/test_file_read_data_libsvm_0.txt
new file mode 100644
index 0000000..6169a0c
--- /dev/null
+++ b/tests/data/test_file_read_data_libsvm_0.txt
@@ -0,0 +1,5 @@
+2 0:0.7065937536993949 1:0.7016517970438980 2:0.1548611397288129
+1 0:0.4604987687863951 1:0.6374142980176117 2:0.0370930278245423
+3 0:0.3798777132278375 1:0.5745070018747664 2:0.2570906697837264
+4 0:0.2789376050039792 1:0.4853242744610165 2:0.1894010436762711
+3 0:0.7630904372339489 1:0.1341546320318005 2:0.6827430912944857
diff --git a/tests/data/test_file_read_data_no_label_libsvm.txt b/tests/data/test_file_read_data_no_label_libsvm.txt
new file mode 100644
index 0000000..5452216
--- /dev/null
+++ b/tests/data/test_file_read_data_no_label_libsvm.txt
@@ -0,0 +1,5 @@
+1:0.7065937536993949 2:0.7016517970438980 3:0.1548611397288129
+1:0.4604987687863951 2:0.6374142980176117 3:0.0370930278245423
+1:0.3798777132278375 2:0.5745070018747664 3:0.2570906697837264
+1:0.2789376050039792 2:0.4853242744610165 3:0.1894010436762711
+1:0.7630904372339489 2:0.1341546320318005 3:0.6827430912944857
diff --git a/tests/data/test_file_read_data_sparse_libsvm.txt b/tests/data/test_file_read_data_sparse_libsvm.txt
new file mode 100644
index 0000000..ff41fe7
--- /dev/null
+++ b/tests/data/test_file_read_data_sparse_libsvm.txt
@@ -0,0 +1,10 @@
+2 2:0.7016517970438980
+1 3:0.0370930278245423
+3
+4 2:0.4853242744610165
+3 1:0.7630904372339489
+1
+2
+3
+4
+1
diff --git a/tests/src/test_gensvm_io.c b/tests/src/test_gensvm_io.c
index f1a7d22..29776d6 100644
--- a/tests/src/test_gensvm_io.c
+++ b/tests/src/test_gensvm_io.c
@@ -265,6 +265,321 @@ char *test_gensvm_read_data_no_label()
return NULL;
}
+char *test_gensvm_read_data_libsvm()
+{
+ char *filename = "./data/test_file_read_data_libsvm.txt";
+ struct GenData *data = gensvm_init_data();
+
+ // start test code //
+ gensvm_read_data_libsvm(data, filename);
+
+ // check if dimensions are correctly read
+ mu_assert(data->n == 5, "Incorrect value for n");
+ mu_assert(data->m == 3, "Incorrect value for m");
+ mu_assert(data->r == 3, "Incorrect value for r");
+ mu_assert(data->K == 4, "Incorrect value for K");
+
+ // check if all data is read correctly.
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949,
+ "Incorrect Z value at 0, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980,
+ "Incorrect Z value at 0, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129,
+ "Incorrect Z value at 0, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951,
+ "Incorrect Z value at 1, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117,
+ "Incorrect Z value at 1, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423,
+ "Incorrect Z value at 1, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375,
+ "Incorrect Z value at 2, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664,
+ "Incorrect Z value at 2, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264,
+ "Incorrect Z value at 2, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792,
+ "Incorrect Z value at 3, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165,
+ "Incorrect Z value at 3, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711,
+ "Incorrect Z value at 3, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489,
+ "Incorrect Z value at 4, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005,
+ "Incorrect Z value at 4, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857,
+ "Incorrect Z value at 4, 3");
+ // check if RAW = Z
+ mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer");
+
+ // check if labels read correctly
+ mu_assert(data->y[0] == 2, "Incorrect label read at 0");
+ mu_assert(data->y[1] == 1, "Incorrect label read at 1");
+ mu_assert(data->y[2] == 3, "Incorrect label read at 2");
+ mu_assert(data->y[3] == 4, "Incorrect label read at 3");
+ mu_assert(data->y[4] == 3, "Incorrect label read at 4");
+
+ // check if the column of ones is added
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1,
+ "Incorrect Z value at 0, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1,
+ "Incorrect Z value at 1, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1,
+ "Incorrect Z value at 2, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1,
+ "Incorrect Z value at 3, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1,
+ "Incorrect Z value at 4, 0");
+
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_gensvm_read_data_libsvm_0based()
+{
+ char *filename = "./data/test_file_read_data_libsvm_0.txt";
+ struct GenData *data = gensvm_init_data();
+
+ // start test code //
+ gensvm_read_data_libsvm(data, filename);
+
+ // check if dimensions are correctly read
+ mu_assert(data->n == 5, "Incorrect value for n");
+ mu_assert(data->m == 3, "Incorrect value for m");
+ mu_assert(data->r == 3, "Incorrect value for r");
+ mu_assert(data->K == 4, "Incorrect value for K");
+
+ // check if all data is read correctly.
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949,
+ "Incorrect Z value at 0, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980,
+ "Incorrect Z value at 0, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129,
+ "Incorrect Z value at 0, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951,
+ "Incorrect Z value at 1, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117,
+ "Incorrect Z value at 1, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423,
+ "Incorrect Z value at 1, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375,
+ "Incorrect Z value at 2, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664,
+ "Incorrect Z value at 2, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264,
+ "Incorrect Z value at 2, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792,
+ "Incorrect Z value at 3, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165,
+ "Incorrect Z value at 3, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711,
+ "Incorrect Z value at 3, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489,
+ "Incorrect Z value at 4, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005,
+ "Incorrect Z value at 4, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857,
+ "Incorrect Z value at 4, 3");
+ // check if RAW = Z
+ mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer");
+
+ // check if labels read correctly
+ mu_assert(data->y[0] == 2, "Incorrect label read at 0");
+ mu_assert(data->y[1] == 1, "Incorrect label read at 1");
+ mu_assert(data->y[2] == 3, "Incorrect label read at 2");
+ mu_assert(data->y[3] == 4, "Incorrect label read at 3");
+ mu_assert(data->y[4] == 3, "Incorrect label read at 4");
+
+ // check if the column of ones is added
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1,
+ "Incorrect Z value at 0, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1,
+ "Incorrect Z value at 1, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1,
+ "Incorrect Z value at 2, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1,
+ "Incorrect Z value at 3, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1,
+ "Incorrect Z value at 4, 0");
+
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_gensvm_read_data_libsvm_sparse()
+{
+	char *filename = "./data/test_file_read_data_sparse_libsvm.txt";
+	struct GenData *data = gensvm_init_data();
+
+	// start test code //
+	gensvm_read_data_libsvm(data, filename);
+
+	// check if dimensions are correctly read
+	mu_assert(data->n == 10, "Incorrect value for n");
+	mu_assert(data->m == 3, "Incorrect value for m");
+	mu_assert(data->r == 3, "Incorrect value for r");
+	mu_assert(data->K == 4, "Incorrect value for K");
+
+	// check if dense data pointers are NULL
+	mu_assert(data->Z == NULL, "Z pointer isn't NULL");
+	mu_assert(data->RAW == NULL, "RAW pointer isn't NULL");
+
+	// check sparse data structure
+	mu_assert(data->spZ != NULL, "spZ is NULL");
+	mu_assert(data->spZ->nnz == 14, "Incorrect nnz");
+	mu_assert(data->spZ->n_row == 10, "Incorrect n_row");
+	mu_assert(data->spZ->n_col == 4, "Incorrect n_col");
+
+	// check sparse values (each row starts with the 1.0 of the ones column)
+	mu_assert(data->spZ->values[0] == 1.0,
+			"Incorrect nonzero value at 0");
+	mu_assert(data->spZ->values[1] == 0.7016517970438980,
+			"Incorrect nonzero value at 1");
+	mu_assert(data->spZ->values[2] == 1.0,
+			"Incorrect nonzero value at 2");
+	mu_assert(data->spZ->values[3] == 0.0370930278245423,
+			"Incorrect nonzero value at 3");
+	mu_assert(data->spZ->values[4] == 1.0,
+			"Incorrect nonzero value at 4");
+	mu_assert(data->spZ->values[5] == 1.0,
+			"Incorrect nonzero value at 5");
+	mu_assert(data->spZ->values[6] == 0.4853242744610165,
+			"Incorrect nonzero value at 6");
+	mu_assert(data->spZ->values[7] == 1.0,
+			"Incorrect nonzero value at 7");
+	mu_assert(data->spZ->values[8] == 0.7630904372339489,
+			"Incorrect nonzero value at 8");
+	mu_assert(data->spZ->values[9] == 1.0,
+			"Incorrect nonzero value at 9");
+	mu_assert(data->spZ->values[10] == 1.0,
+			"Incorrect nonzero value at 10");
+	mu_assert(data->spZ->values[11] == 1.0,
+			"Incorrect nonzero value at 11");
+	mu_assert(data->spZ->values[12] == 1.0,
+			"Incorrect nonzero value at 12");
+	mu_assert(data->spZ->values[13] == 1.0,
+			"Incorrect nonzero value at 13");
+
+	// check sparse row pointers (cumulative nonzero count per row)
+	mu_assert(data->spZ->ia[0] == 0, "Incorrect ia value at 0");
+	mu_assert(data->spZ->ia[1] == 2, "Incorrect ia value at 1");
+	mu_assert(data->spZ->ia[2] == 4, "Incorrect ia value at 2");
+	mu_assert(data->spZ->ia[3] == 5, "Incorrect ia value at 3");
+	mu_assert(data->spZ->ia[4] == 7, "Incorrect ia value at 4");
+	mu_assert(data->spZ->ia[5] == 9, "Incorrect ia value at 5");
+	mu_assert(data->spZ->ia[6] == 10, "Incorrect ia value at 6");
+	mu_assert(data->spZ->ia[7] == 11, "Incorrect ia value at 7");
+	mu_assert(data->spZ->ia[8] == 12, "Incorrect ia value at 8");
+	mu_assert(data->spZ->ia[9] == 13, "Incorrect ia value at 9");
+	mu_assert(data->spZ->ia[10] == 14, "Incorrect ia value at 10");
+
+	// check sparse column indices
+	mu_assert(data->spZ->ja[0] == 0, "Incorrect ja value at 0");
+	mu_assert(data->spZ->ja[1] == 2, "Incorrect ja value at 1");
+	mu_assert(data->spZ->ja[2] == 0, "Incorrect ja value at 2");
+	mu_assert(data->spZ->ja[3] == 3, "Incorrect ja value at 3");
+	mu_assert(data->spZ->ja[4] == 0, "Incorrect ja value at 4");
+	mu_assert(data->spZ->ja[5] == 0, "Incorrect ja value at 5");
+	mu_assert(data->spZ->ja[6] == 2, "Incorrect ja value at 6");
+	mu_assert(data->spZ->ja[7] == 0, "Incorrect ja value at 7");
+	mu_assert(data->spZ->ja[8] == 1, "Incorrect ja value at 8");
+	mu_assert(data->spZ->ja[9] == 0, "Incorrect ja value at 9");
+	mu_assert(data->spZ->ja[10] == 0, "Incorrect ja value at 10");
+	mu_assert(data->spZ->ja[11] == 0, "Incorrect ja value at 11");
+	mu_assert(data->spZ->ja[12] == 0, "Incorrect ja value at 12");
+	mu_assert(data->spZ->ja[13] == 0, "Incorrect ja value at 13");
+
+	// check if labels read correctly (first five rows)
+	mu_assert(data->y[0] == 2, "Incorrect label read at 0");
+	mu_assert(data->y[1] == 1, "Incorrect label read at 1");
+	mu_assert(data->y[2] == 3, "Incorrect label read at 2");
+	mu_assert(data->y[3] == 4, "Incorrect label read at 3");
+	mu_assert(data->y[4] == 3, "Incorrect label read at 4");
+
+	// end test code //
+
+	gensvm_free_data(data);
+
+	return NULL;
+}
+
+char *test_gensvm_read_data_libsvm_no_label()
+{
+ char *filename = "./data/test_file_read_data_no_label_libsvm.txt";
+ struct GenData *data = gensvm_init_data();
+
+ // start test code //
+ gensvm_read_data_libsvm(data, filename);
+
+ // check if dimensions are correctly read
+ mu_assert(data->n == 5, "Incorrect value for n");
+ mu_assert(data->m == 3, "Incorrect value for m");
+ mu_assert(data->r == 3, "Incorrect value for r");
+ mu_assert(data->K == 0, "Incorrect value for K");
+
+ // check if all data is read correctly.
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949,
+ "Incorrect Z value at 0, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980,
+ "Incorrect Z value at 0, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129,
+ "Incorrect Z value at 0, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951,
+ "Incorrect Z value at 1, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117,
+ "Incorrect Z value at 1, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423,
+ "Incorrect Z value at 1, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375,
+ "Incorrect Z value at 2, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664,
+ "Incorrect Z value at 2, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264,
+ "Incorrect Z value at 2, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792,
+ "Incorrect Z value at 3, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165,
+ "Incorrect Z value at 3, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711,
+ "Incorrect Z value at 3, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489,
+ "Incorrect Z value at 4, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005,
+ "Incorrect Z value at 4, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857,
+ "Incorrect Z value at 4, 3");
+ // check if RAW = Z
+ mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer");
+
+ // check that no label array was allocated (the file has no labels)
+ mu_assert(data->y == NULL, "Outcome pointer is not NULL");
+
+ // check if the column of ones is added
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1,
+ "Incorrect Z value at 0, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1,
+ "Incorrect Z value at 1, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1,
+ "Incorrect Z value at 2, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1,
+ "Incorrect Z value at 3, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1,
+ "Incorrect Z value at 4, 0");
+
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
char *test_gensvm_read_model()
{
struct GenModel *model = gensvm_init_model();
@@ -536,6 +851,11 @@ char *all_tests()
mu_run_test(test_gensvm_read_data);
mu_run_test(test_gensvm_read_data_sparse);
mu_run_test(test_gensvm_read_data_no_label);
+ mu_run_test(test_gensvm_read_data_libsvm);
+ mu_run_test(test_gensvm_read_data_libsvm_0based);
+ mu_run_test(test_gensvm_read_data_libsvm_sparse);
+ mu_run_test(test_gensvm_read_data_libsvm_no_label);
+
mu_run_test(test_gensvm_read_model);
mu_run_test(test_gensvm_write_model);
mu_run_test(test_gensvm_write_predictions);