From ab119782aca1a2eb9216cd721bca3ab9a0235911 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Wed, 7 Dec 2016 21:05:57 +0100 Subject: allow datasets to be stored in libsvm/svmlight format --- include/gensvm_globals.h | 10 +- include/gensvm_io.h | 1 + include/gensvm_sparse.h | 1 + include/gensvm_strutil.h | 2 + src/GenSVMgrid.c | 15 +- src/GenSVMtraintest.c | 25 +- src/gensvm_io.c | 263 ++++++++++++++++- src/gensvm_sparse.c | 28 +- src/gensvm_strutil.c | 112 ++++++++ tests/data/test_file_read_data_libsvm.txt | 5 + tests/data/test_file_read_data_libsvm_0.txt | 5 + tests/data/test_file_read_data_no_label_libsvm.txt | 5 + tests/data/test_file_read_data_sparse_libsvm.txt | 10 + tests/src/test_gensvm_io.c | 320 +++++++++++++++++++++ 14 files changed, 783 insertions(+), 19 deletions(-) create mode 100644 tests/data/test_file_read_data_libsvm.txt create mode 100644 tests/data/test_file_read_data_libsvm_0.txt create mode 100644 tests/data/test_file_read_data_no_label_libsvm.txt create mode 100644 tests/data/test_file_read_data_sparse_libsvm.txt diff --git a/include/gensvm_globals.h b/include/gensvm_globals.h index 1151069..1aca458 100644 --- a/include/gensvm_globals.h +++ b/include/gensvm_globals.h @@ -39,15 +39,17 @@ #include "gensvm_memory.h" // all system libraries are included here +#include +#include +#include +#include +#include #include +#include #include #include -#include #include -#include #include -#include -#include // ########################### Type definitions ########################### // diff --git a/include/gensvm_io.h b/include/gensvm_io.h index afcf4a3..c18e9a2 100644 --- a/include/gensvm_io.h +++ b/include/gensvm_io.h @@ -37,6 +37,7 @@ // function declarations void gensvm_read_data(struct GenData *dataset, char *data_file); +void gensvm_read_data_libsvm(struct GenData *dataset, char *data_file); void gensvm_read_model(struct GenModel *model, char *model_filename); void gensvm_write_model(struct GenModel *model, char *output_filename); diff --git a/include/gensvm_sparse.h b/include/gensvm_sparse.h index 570e19a..f897889 100644 --- a/include/gensvm_sparse.h +++ b/include/gensvm_sparse.h @@ -71,6 +71,7 @@ struct GenSparse { struct GenSparse *gensvm_init_sparse(); void gensvm_free_sparse(struct GenSparse *sp); long gensvm_count_nnz(double *A, long rows, long cols); +bool gensvm_nnz_comparison(long nnz, long rows, long cols); bool gensvm_could_sparse(double *A, long rows, long cols); struct GenSparse *gensvm_dense_to_sparse(double *A, long rows, long cols); double *gensvm_sparse_to_dense(struct GenSparse *A); diff --git a/include/gensvm_strutil.h b/include/gensvm_strutil.h index 4e86944..1a2dccc 100644 --- a/include/gensvm_strutil.h +++ b/include/gensvm_strutil.h @@ -35,6 +35,8 @@ bool str_startswith(const char *str, const char *pre); bool str_endswith(const char *str, const char *suf); +bool str_contains_char(const char *str, const char c); +char **str_split(char *original, const char *delims, int *len_ret); void next_line(FILE *fid, char *filename); char *get_line(FILE *fid, char *filename, char *buffer); diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 3799de3..41d1070 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -72,7 +72,8 @@ void exit_with_help(char **argv) printf("Usage: %s [options] grid_file\n", argv[0]); printf("Options:\n"); printf("-h | -help : print this help.\n"); - printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-x : data files are in LibSVM/SVMlight format\n"); exit(EXIT_FAILURE); } @@ -97,6 +98,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; struct GenGrid *grid = gensvm_init_grid(); @@ -108,12 +110,16 @@ int main(int argc, char **argv) || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); parse_command_line(argc, argv, input_filename); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); read_grid_from_file(input_filename, grid); note("Reading data from %s\n", grid->train_data_file); - gensvm_read_data(train_data, grid->train_data_file); + if (libsvm_format) + gensvm_read_data_libsvm(train_data, grid->train_data_file); + else + gensvm_read_data(train_data, grid->train_data_file); // check labels of training data gensvm_check_outcome_contiguous(train_data); @@ -196,6 +202,9 @@ void parse_command_line(int argc, char **argv, char *input_filename) GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: fprintf(stderr, "Unknown option: -%c\n", argv[i-1][1]); @@ -267,7 +276,7 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid) if (fid == NULL) { fprintf(stderr, "Error opening grid file %s\n", input_filename); - exit(1); + exit(EXIT_FAILURE); } grid->traintype = CV; while ( fgets(buffer, GENSVM_MAX_LINE_LENGTH, fid) != NULL ) { diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 8bebf3a..a275e57 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -77,7 +77,6 @@ void exit_with_help(char **argv) "Huber hinge\n"); printf("-l lambda : set the value of lambda " "(lambda > 0)\n"); - printf("-s seed_model_file : use previous model as seed for V\n"); printf("-m model_output_file : write model output to file " "(not saved if no file provided)\n"); printf("-o prediction_output : write predictions of test data to " @@ -88,8 +87,11 @@ void exit_with_help(char **argv) "errors!)\n"); printf("-r rho : choose the weigth specification " "(1 = unit, 2 = group)\n"); + printf("-s seed_model_file : use previous model as seed for V\n"); printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, " "3=SIGMOID)\n"); + printf("-x : data files are in LibSVM/SVMlight " + "format\n"); printf("\n"); exit(EXIT_FAILURE); @@ -108,6 +110,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; long i, *predy = NULL; double performance; @@ -130,9 +133,15 @@ int main(int argc, char **argv) parse_command_line(argc, argv, model, &model_inputfile, &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); + + // read data from file + if (libsvm_format) + gensvm_read_data_libsvm(traindata, training_inputfile); + else + gensvm_read_data(traindata, training_inputfile); - // read data from file and check labels - gensvm_read_data(traindata, training_inputfile); + // check labels for consistency if (!gensvm_check_outcome_contiguous(traindata)) { err("[GenSVM Error]: Class labels should start from 1 and " "have no gaps. Please reformat your data.\n"); @@ -168,7 +177,11 @@ int main(int argc, char **argv) // if we also have a test set, predict labels and write to predictions // to an output file if specified if (testing_inputfile != NULL) { - gensvm_read_data(testdata, testing_inputfile); + // read the test data + if (libsvm_format) + gensvm_read_data_libsvm(testdata, testing_inputfile); + else + gensvm_read_data(testdata, testing_inputfile); // check if we are sparse and want nonlinearity if (testdata->Z == NULL && model->kerneltype != K_LINEAR) { @@ -258,6 +271,7 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = stderr; // parse options + // note: flags that don't have an argument should decrement i for (i=1; i=argc) { @@ -311,6 +325,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: // this one should always print explicitly to // stderr, even if '-q' is supplied, because diff --git a/src/gensvm_io.c b/src/gensvm_io.c index 0d14144..d41b2ef 100644 --- a/src/gensvm_io.c +++ b/src/gensvm_io.c @@ -6,7 +6,7 @@ * * @details * This file contains functions for reading and writing model files, and data - * files. It also contains a function for generating a string of the current + * files. It also contains a function for generating a string of the current * time, used in writing output files. * * @copyright @@ -38,7 +38,7 @@ * Read the data from the data_file. The data matrix X is augmented * with a column of ones, to get the matrix Z. The data is expected * to follow a specific format, which is specified in the @ref spec_data_file. - * The class labels are assumed to be in the interval [1 .. K], which can be + * The class labels are assumed to be in the interval [1 .. K], which can be * checked using the function gensvm_check_outcome_contiguous(). * * @param[in,out] dataset initialized GenData struct @@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) } } +/** + * @brief Print an error to the screen and exit (copied from LibSVM) + * + * @param[in] line_num line number where the error occured + * + */ +void exit_input_error(int line_num) +{ + err("[GenSVM Error]: Wrong input format on line: %i\n", line_num); + exit(EXIT_FAILURE); +} + +/** + * @brief Read data from a file in LibSVM/SVMlight format + * + * @details + * This function reads data from a file where the data is stored in + * LibSVM/SVMlight format. This is a sparse data format, which can be + * beneficial for certain applications. The advantage of having this function + * here is twofold: 1) existing datasets where data is stored in + * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets + * which are too large for memory when kept in dense format can be loaded + * efficiently into GenSVM. + * + * @note + * This code is based on the read_problem() function in the svm-train.c + * file of LibSVM. It has however been expanded to be able to handle data + * files without labels. + * + * @note + * This file tries to detect whether 1-based or 0-based indexing is used in + * the data file. By default 1-based indexing is used, but if an index is + * found with value 0, 0-based indexing is assumed. + * + * @sa + * gensvm_read_problem() + * + * @param[in] data GenData structure + * @param[in] data_file filename of the datafile + * + */ +void gensvm_read_data_libsvm(struct GenData *data, char *data_file) +{ + bool do_sparse, zero_based = false; + long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels, + min_index = 1; + int n_big, n_small, big_start; + double value; + FILE *fid = NULL; + char *label = NULL, + *endptr = NULL, + **big_parts = NULL, + **small_parts = NULL; + char buf[GENSVM_MAX_LINE_LENGTH]; + + fid = fopen(data_file, "r"); + if (fid == NULL) { + // LCOV_EXCL_START + err("[GenSVM Error]: Datafile %s could not be opened.\n", + data_file); + exit(EXIT_FAILURE); + // LCOV_EXCL_STOP + } + + // first count the number of elements + n = 0; + m = -1; + + num_labels = 0; + nnz = 0; + + while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) { + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + // record if this line has a label (first part has no colon) + num_labels += (!str_contains_char(big_parts[0], ':')); + + // check for each part if it is a index:value pair + for (i=0; i 0 && num_labels != n) { + err("[GenSVM Error]: There are some lines with missing " + "labels. Please fix this before " + "continuing.\n"); + exit(EXIT_FAILURE); + } + + // don't forget the column of ones + nnz += n; + + // deal with 0-based or 1-based indexing in the LibSVM file + if (min_index == 0) { + m++; + zero_based = true; + } + + // check if sparsity is worth it + do_sparse = gensvm_nnz_comparison(nnz, n, m+1); + if (do_sparse) { + data->spZ = gensvm_init_sparse(); + data->spZ->nnz = nnz; + data->spZ->n_row = n; + data->spZ->n_col = m+1; + data->spZ->values = Calloc(double, nnz); + data->spZ->ia = Calloc(long, n+1); + data->spZ->ja = Calloc(long, nnz); + data->spZ->ia[0] = 0; + } else { + data->RAW = Calloc(double, n*(m+1)); + data->Z = data->RAW; + } + if (num_labels > 0) + data->y = Calloc(long, n); + + K = 0; + cnt = 0; + for (i=0; iy[i] = tmp; + + // keep track of maximum K + K = maximum(K, data->y[i]); + + // increment big part index + big_start++; + } + + row_cnt = 0; + // set the first element in the row to 1 + if (do_sparse) { + data->spZ->values[cnt] = 1.0; + data->spZ->ja[cnt] = 0; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, 0, 1.0); + } + + // read the rest of the line + for (j=big_start; jspZ->values[cnt] = value; + data->spZ->ja[cnt] = index + zero_based; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, + index + zero_based, value); + } + + // free the small parts + free(small_parts[0]); + free(small_parts[1]); + free(small_parts); + } + + if (do_sparse) { + data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt; + } + + // free the big parts + for (j=0; jn = n; + data->m = m; + data->r = m; + data->K = K; + +} /** * @brief Read model from file diff --git a/src/gensvm_sparse.c b/src/gensvm_sparse.c index ce99b3b..1a9317e 100644 --- a/src/gensvm_sparse.c +++ b/src/gensvm_sparse.c @@ -90,6 +90,23 @@ long gensvm_count_nnz(double *A, long rows, long cols) return nnz; } +/** + * @brief Compare the number of nonzeros is such that sparsity if worth it + * + * @details + * This is a utility function, see gensvm_could_sparse() for more info. + * + * @param[in] nnz number of nonzero elements + * @param[in] rows number of rows + * @param[in] cols number of columns + * + * @return whether or not sparsity is worth it + */ +bool gensvm_nnz_comparison(long nnz, long rows, long cols) +{ + return (nnz < (rows*(cols-1.0)-1.0)/2.0); +} + /** * @brief Check if it is worthwile to convert to a sparse matrix * @@ -100,20 +117,19 @@ long gensvm_count_nnz(double *A, long rows, long cols) * the amount of nonzero entries is small enough, the function returns the * number of nonzeros. If it is too big, it returns -1. * + * @sa + * gensvm_nnz_comparison() + * * @param[in] A matrix in dense format (RowMajor order) * @param[in] rows number of rows of A * @param[in] cols number of columns of A * - * @return + * @return whether or not sparsity is worth it */ bool gensvm_could_sparse(double *A, long rows, long cols) { long nnz = gensvm_count_nnz(A, rows, cols); - - if (nnz < (rows*(cols-1.0)-1.0)/2.0) { - return true; - } - return false; + return gensvm_nnz_comparison(nnz, rows, cols); } diff --git a/src/gensvm_strutil.c b/src/gensvm_strutil.c index f6bb8f7..7bca2ca 100644 --- a/src/gensvm_strutil.c +++ b/src/gensvm_strutil.c @@ -62,6 +62,118 @@ bool str_endswith(const char *str, const char *suf) lensuf) == 0; } +/** + * @brief Check if a string contains a char + * + * @details + * Simple utility function to check if a char occurs in a string. + * + * @param[in] str input string + * @param[in] c character + * + * @return number of times c occurs in str + */ +bool str_contains_char(const char *str, const char c) +{ + size_t i, len = strlen(str); + for (i=0; in == 5, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 4, "Incorrect value for K"); + + // check if all data is read correctly. + mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949, + "Incorrect Z value at 0, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980, + "Incorrect Z value at 0, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129, + "Incorrect Z value at 0, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951, + "Incorrect Z value at 1, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117, + "Incorrect Z value at 1, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423, + "Incorrect Z value at 1, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375, + "Incorrect Z value at 2, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664, + "Incorrect Z value at 2, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264, + "Incorrect Z value at 2, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792, + "Incorrect Z value at 3, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165, + "Incorrect Z value at 3, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711, + "Incorrect Z value at 3, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489, + "Incorrect Z value at 4, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005, + "Incorrect Z value at 4, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857, + "Incorrect Z value at 4, 3"); + // check if RAW = Z + mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer"); + + // check if labels read correctly + mu_assert(data->y[0] == 2, "Incorrect label read at 0"); + mu_assert(data->y[1] == 1, "Incorrect label read at 1"); + mu_assert(data->y[2] == 3, "Incorrect label read at 2"); + mu_assert(data->y[3] == 4, "Incorrect label read at 3"); + mu_assert(data->y[4] == 3, "Incorrect label read at 4"); + + // check if the column of ones is added + mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1, + "Incorrect Z value at 0, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1, + "Incorrect Z value at 1, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1, + "Incorrect Z value at 2, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1, + "Incorrect Z value at 3, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1, + "Incorrect Z value at 4, 0"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_gensvm_read_data_libsvm_0based() +{ + char *filename = "./data/test_file_read_data_libsvm_0.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 5, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 4, "Incorrect value for K"); + + // check if all data is read correctly. + mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949, + "Incorrect Z value at 0, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980, + "Incorrect Z value at 0, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129, + "Incorrect Z value at 0, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951, + "Incorrect Z value at 1, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117, + "Incorrect Z value at 1, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423, + "Incorrect Z value at 1, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375, + "Incorrect Z value at 2, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664, + "Incorrect Z value at 2, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264, + "Incorrect Z value at 2, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792, + "Incorrect Z value at 3, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165, + "Incorrect Z value at 3, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711, + "Incorrect Z value at 3, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489, + "Incorrect Z value at 4, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005, + "Incorrect Z value at 4, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857, + "Incorrect Z value at 4, 3"); + // check if RAW = Z + mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer"); + + // check if labels read correctly + mu_assert(data->y[0] == 2, "Incorrect label read at 0"); + mu_assert(data->y[1] == 1, "Incorrect label read at 1"); + mu_assert(data->y[2] == 3, "Incorrect label read at 2"); + mu_assert(data->y[3] == 4, "Incorrect label read at 3"); + mu_assert(data->y[4] == 3, "Incorrect label read at 4"); + + // check if the column of ones is added + mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1, + "Incorrect Z value at 0, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1, + "Incorrect Z value at 1, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1, + "Incorrect Z value at 2, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1, + "Incorrect Z value at 3, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1, + "Incorrect Z value at 4, 0"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_gensvm_read_data_libsvm_sparse() +{ + char *filename = "./data/test_file_read_data_sparse_libsvm.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 10, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 4, "Incorrect value for K"); + + // check if dense data pointers are NULL + mu_assert(data->Z == NULL, "Z pointer isn't NULL"); + mu_assert(data->RAW == NULL, "RAW pointer isn't NULL"); + + // check sparse data structure + mu_assert(data->spZ != NULL, "spZ is NULL"); + mu_assert(data->spZ->nnz == 14, "Incorrect nnz"); + mu_assert(data->spZ->n_row == 10, "Incorrect n_row"); + mu_assert(data->spZ->n_col == 4, "Incorrect n_col"); + + // check sparse values + mu_assert(data->spZ->values[0] == 1.0, + "Incorrect nonzero value at 0"); + mu_assert(data->spZ->values[1] == 0.7016517970438980, + "Incorrect nonzero value at 1"); + mu_assert(data->spZ->values[2] == 1.0, + "Incorrect nonzero value at 2"); + mu_assert(data->spZ->values[3] == 0.0370930278245423, + "Incorrect nonzero value at 3"); + mu_assert(data->spZ->values[4] == 1.0, + "Incorrect nonzero value at 4"); + mu_assert(data->spZ->values[5] == 1.0, + "Incorrect nonzero value at 5"); + mu_assert(data->spZ->values[6] == 0.4853242744610165, + "Incorrect nonzero value at 6"); + mu_assert(data->spZ->values[7] == 1.0, + "Incorrect nonzero value at 7"); + mu_assert(data->spZ->values[8] == 0.7630904372339489, + "Incorrect nonzero value at 8"); + mu_assert(data->spZ->values[9] == 1.0, + "Incorrect nonzero value at 9"); + mu_assert(data->spZ->values[10] == 1.0, + "Incorrect nonzero value at 10"); + mu_assert(data->spZ->values[11] == 1.0, + "Incorrect nonzero value at 11"); + mu_assert(data->spZ->values[12] == 1.0, + "Incorrect nonzero value at 12"); + mu_assert(data->spZ->values[13] == 1.0, + "Incorrect nonzero value at 13"); + + // check sparse row lengths + mu_assert(data->spZ->ia[0] == 0, "Incorrect ia value at 0"); + mu_assert(data->spZ->ia[1] == 2, "Incorrect ia value at 1"); + mu_assert(data->spZ->ia[2] == 4, "Incorrect ia value at 2"); + mu_assert(data->spZ->ia[3] == 5, "Incorrect ia value at 3"); + mu_assert(data->spZ->ia[4] == 7, "Incorrect ia value at 4"); + mu_assert(data->spZ->ia[5] == 9, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[6] == 10, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[7] == 11, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[8] == 12, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[9] == 13, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[10] == 14, "Incorrect ia value at 5"); + + // check sparse column indices + mu_assert(data->spZ->ja[0] == 0, "Incorrect ja value at 0"); + mu_assert(data->spZ->ja[1] == 2, "Incorrect ja value at 1"); + mu_assert(data->spZ->ja[2] == 0, "Incorrect ja value at 2"); + mu_assert(data->spZ->ja[3] == 3, "Incorrect ja value at 3"); + mu_assert(data->spZ->ja[4] == 0, "Incorrect ja value at 4"); + mu_assert(data->spZ->ja[5] == 0, "Incorrect ja value at 5"); + mu_assert(data->spZ->ja[6] == 2, "Incorrect ja value at 6"); + mu_assert(data->spZ->ja[7] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[8] == 1, "Incorrect ja value at 8"); + mu_assert(data->spZ->ja[9] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[10] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[11] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[12] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[13] == 0, "Incorrect ja value at 7"); + + // check if labels read correctly + mu_assert(data->y[0] == 2, "Incorrect label read at 0"); + mu_assert(data->y[1] == 1, "Incorrect label read at 1"); + mu_assert(data->y[2] == 3, "Incorrect label read at 2"); + mu_assert(data->y[3] == 4, "Incorrect label read at 3"); + mu_assert(data->y[4] == 3, "Incorrect label read at 4"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_gensvm_read_data_libsvm_no_label() +{ + char *filename = "./data/test_file_read_data_no_label_libsvm.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 5, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 0, "Incorrect value for K"); + + // check if all data is read correctly. + mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949, + "Incorrect Z value at 0, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980, + "Incorrect Z value at 0, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129, + "Incorrect Z value at 0, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951, + "Incorrect Z value at 1, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117, + "Incorrect Z value at 1, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423, + "Incorrect Z value at 1, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375, + "Incorrect Z value at 2, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664, + "Incorrect Z value at 2, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264, + "Incorrect Z value at 2, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792, + "Incorrect Z value at 3, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165, + "Incorrect Z value at 3, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711, + "Incorrect Z value at 3, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489, + "Incorrect Z value at 4, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005, + "Incorrect Z value at 4, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857, + "Incorrect Z value at 4, 3"); + // check if RAW = Z + mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer"); + + // check if labels read correctly + mu_assert(data->y == NULL, "Outcome pointer is not NULL"); + + // check if the column of ones is added + mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1, + "Incorrect Z value at 0, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1, + "Incorrect Z value at 1, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1, + "Incorrect Z value at 2, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1, + "Incorrect Z value at 3, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1, + "Incorrect Z value at 4, 0"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + char *test_gensvm_read_model() { struct GenModel *model = gensvm_init_model(); @@ -536,6 +851,11 @@ char *all_tests() mu_run_test(test_gensvm_read_data); mu_run_test(test_gensvm_read_data_sparse); mu_run_test(test_gensvm_read_data_no_label); + mu_run_test(test_gensvm_read_data_libsvm); + mu_run_test(test_gensvm_read_data_libsvm_0based); + mu_run_test(test_gensvm_read_data_libsvm_sparse); + mu_run_test(test_gensvm_read_data_libsvm_no_label); + mu_run_test(test_gensvm_read_model); mu_run_test(test_gensvm_write_model); mu_run_test(test_gensvm_write_predictions); -- cgit v1.2.3