diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2016-12-07 21:05:57 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2016-12-07 21:05:57 +0100 |
| commit | ab119782aca1a2eb9216cd721bca3ab9a0235911 (patch) | |
| tree | 6701667265066758f8150dc5ba29ceee0c28ca83 | |
| parent | moved check for class labels to seperate module (diff) | |
| download | gensvm-ab119782aca1a2eb9216cd721bca3ab9a0235911.tar.gz gensvm-ab119782aca1a2eb9216cd721bca3ab9a0235911.zip | |
allow datasets to be stored in libsvm/svmlight format
| -rw-r--r-- | include/gensvm_globals.h | 10 | ||||
| -rw-r--r-- | include/gensvm_io.h | 1 | ||||
| -rw-r--r-- | include/gensvm_sparse.h | 1 | ||||
| -rw-r--r-- | include/gensvm_strutil.h | 2 | ||||
| -rw-r--r-- | src/GenSVMgrid.c | 15 | ||||
| -rw-r--r-- | src/GenSVMtraintest.c | 25 | ||||
| -rw-r--r-- | src/gensvm_io.c | 263 | ||||
| -rw-r--r-- | src/gensvm_sparse.c | 28 | ||||
| -rw-r--r-- | src/gensvm_strutil.c | 112 | ||||
| -rw-r--r-- | tests/data/test_file_read_data_libsvm.txt | 5 | ||||
| -rw-r--r-- | tests/data/test_file_read_data_libsvm_0.txt | 5 | ||||
| -rw-r--r-- | tests/data/test_file_read_data_no_label_libsvm.txt | 5 | ||||
| -rw-r--r-- | tests/data/test_file_read_data_sparse_libsvm.txt | 10 | ||||
| -rw-r--r-- | tests/src/test_gensvm_io.c | 320 |
14 files changed, 783 insertions, 19 deletions
diff --git a/include/gensvm_globals.h b/include/gensvm_globals.h index 1151069..1aca458 100644 --- a/include/gensvm_globals.h +++ b/include/gensvm_globals.h @@ -39,15 +39,17 @@ #include "gensvm_memory.h" // all system libraries are included here +#include <cblas.h> +#include <ctype.h> +#include <errno.h> +#include <limits.h> +#include <math.h> #include <stdarg.h> +#include <stdbool.h> #include <stdio.h> #include <stdlib.h> -#include <stdbool.h> #include <string.h> -#include <math.h> #include <time.h> -#include <cblas.h> -#include <limits.h> // ########################### Type definitions ########################### // diff --git a/include/gensvm_io.h b/include/gensvm_io.h index afcf4a3..c18e9a2 100644 --- a/include/gensvm_io.h +++ b/include/gensvm_io.h @@ -37,6 +37,7 @@ // function declarations void gensvm_read_data(struct GenData *dataset, char *data_file); +void gensvm_read_data_libsvm(struct GenData *dataset, char *data_file); void gensvm_read_model(struct GenModel *model, char *model_filename); void gensvm_write_model(struct GenModel *model, char *output_filename); diff --git a/include/gensvm_sparse.h b/include/gensvm_sparse.h index 570e19a..f897889 100644 --- a/include/gensvm_sparse.h +++ b/include/gensvm_sparse.h @@ -71,6 +71,7 @@ struct GenSparse { struct GenSparse *gensvm_init_sparse(); void gensvm_free_sparse(struct GenSparse *sp); long gensvm_count_nnz(double *A, long rows, long cols); +bool gensvm_nnz_comparison(long nnz, long rows, long cols); bool gensvm_could_sparse(double *A, long rows, long cols); struct GenSparse *gensvm_dense_to_sparse(double *A, long rows, long cols); double *gensvm_sparse_to_dense(struct GenSparse *A); diff --git a/include/gensvm_strutil.h b/include/gensvm_strutil.h index 4e86944..1a2dccc 100644 --- a/include/gensvm_strutil.h +++ b/include/gensvm_strutil.h @@ -35,6 +35,8 @@ bool str_startswith(const char *str, const char *pre); bool str_endswith(const char *str, const char *suf); +bool str_contains_char(const char *str, const char c); +char **str_split(char *original, const char *delims, int *len_ret); void next_line(FILE *fid, char *filename); char *get_line(FILE *fid, char *filename, char *buffer); diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 3799de3..41d1070 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -72,7 +72,8 @@ void exit_with_help(char **argv) printf("Usage: %s [options] grid_file\n", argv[0]); printf("Options:\n"); printf("-h | -help : print this help.\n"); - printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-q : quiet mode (no output, not even errors!)\n"); + printf("-x : data files are in LibSVM/SVMlight format\n"); exit(EXIT_FAILURE); } @@ -97,6 +98,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; struct GenGrid *grid = gensvm_init_grid(); @@ -108,12 +110,16 @@ int main(int argc, char **argv) || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); parse_command_line(argc, argv, input_filename); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); read_grid_from_file(input_filename, grid); note("Reading data from %s\n", grid->train_data_file); - gensvm_read_data(train_data, grid->train_data_file); + if (libsvm_format) + gensvm_read_data_libsvm(train_data, grid->train_data_file); + else + gensvm_read_data(train_data, grid->train_data_file); // check labels of training data gensvm_check_outcome_contiguous(train_data); @@ -196,6 +202,9 @@ void parse_command_line(int argc, char **argv, char *input_filename) GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: fprintf(stderr, "Unknown option: -%c\n", argv[i-1][1]); @@ -267,7 +276,7 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid) if (fid == NULL) { fprintf(stderr, "Error opening grid file %s\n", input_filename); - exit(1); + exit(EXIT_FAILURE); } grid->traintype = CV; while ( fgets(buffer, GENSVM_MAX_LINE_LENGTH, fid) != NULL ) { diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 8bebf3a..a275e57 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -77,7 +77,6 @@ void exit_with_help(char **argv) "Huber hinge\n"); printf("-l lambda : set the value of lambda " "(lambda > 0)\n"); - printf("-s seed_model_file : use previous model as seed for V\n"); printf("-m model_output_file : write model output to file " "(not saved if no file provided)\n"); printf("-o prediction_output : write predictions of test data to " @@ -88,8 +87,11 @@ void exit_with_help(char **argv) "errors!)\n"); printf("-r rho : choose the weigth specification " "(1 = unit, 2 = group)\n"); + printf("-s seed_model_file : use previous model as seed for V\n"); printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, " "3=SIGMOID)\n"); + printf("-x : data files are in LibSVM/SVMlight " + "format\n"); printf("\n"); exit(EXIT_FAILURE); @@ -108,6 +110,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + bool libsvm_format = false; long i, *predy = NULL; double performance; @@ -130,9 +133,15 @@ int main(int argc, char **argv) parse_command_line(argc, argv, model, &model_inputfile, &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); + libsvm_format = gensvm_check_argv(argc, argv, "-x"); + + // read data from file + if (libsvm_format) + gensvm_read_data_libsvm(traindata, training_inputfile); + else + gensvm_read_data(traindata, training_inputfile); - // read data from file and check labels - gensvm_read_data(traindata, training_inputfile); + // check labels for consistency if (!gensvm_check_outcome_contiguous(traindata)) { err("[GenSVM Error]: Class labels should start from 1 and " "have no gaps. Please reformat your data.\n"); @@ -168,7 +177,11 @@ int main(int argc, char **argv) // if we also have a test set, predict labels and write to predictions // to an output file if specified if (testing_inputfile != NULL) { - gensvm_read_data(testdata, testing_inputfile); + // read the test data + if (libsvm_format) + gensvm_read_data_libsvm(testdata, testing_inputfile); + else + gensvm_read_data(testdata, testing_inputfile); // check if we are sparse and want nonlinearity if (testdata->Z == NULL && model->kerneltype != K_LINEAR) { @@ -258,6 +271,7 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = stderr; // parse options + // note: flags that don't have an argument should decrement i for (i=1; i<argc; i++) { if (argv[i][0] != '-') break; if (++i>=argc) { @@ -311,6 +325,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, GENSVM_ERROR_FILE = NULL; i--; break; + case 'x': + i--; + break; default: // this one should always print explicitly to // stderr, even if '-q' is supplied, because diff --git a/src/gensvm_io.c b/src/gensvm_io.c index 0d14144..d41b2ef 100644 --- a/src/gensvm_io.c +++ b/src/gensvm_io.c @@ -6,7 +6,7 @@ * * @details * This file contains functions for reading and writing model files, and data - * files. It also contains a function for generating a string of the current + * files. It also contains a function for generating a string of the current * time, used in writing output files. * * @copyright @@ -38,7 +38,7 @@ * Read the data from the data_file. The data matrix X is augmented * with a column of ones, to get the matrix Z. The data is expected * to follow a specific format, which is specified in the @ref spec_data_file. - * The class labels are assumed to be in the interval [1 .. K], which can be + * The class labels are assumed to be in the interval [1 .. K], which can be * checked using the function gensvm_check_outcome_contiguous(). * * @param[in,out] dataset initialized GenData struct @@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) } } +/** + * @brief Print an error to the screen and exit (copied from LibSVM) + * + * @param[in] line_num line number where the error occured + * + */ +void exit_input_error(int line_num) +{ + err("[GenSVM Error]: Wrong input format on line: %i\n", line_num); + exit(EXIT_FAILURE); +} + +/** + * @brief Read data from a file in LibSVM/SVMlight format + * + * @details + * This function reads data from a file where the data is stored in + * LibSVM/SVMlight format. This is a sparse data format, which can be + * beneficial for certain applications. The advantage of having this function + * here is twofold: 1) existing datasets where data is stored in + * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets + * which are too large for memory when kept in dense format can be loaded + * efficiently into GenSVM. + * + * @note + * This code is based on the read_problem() function in the svm-train.c + * file of LibSVM. It has however been expanded to be able to handle data + * files without labels. + * + * @note + * This file tries to detect whether 1-based or 0-based indexing is used in + * the data file. By default 1-based indexing is used, but if an index is + * found with value 0, 0-based indexing is assumed. + * + * @sa + * gensvm_read_problem() + * + * @param[in] data GenData structure + * @param[in] data_file filename of the datafile + * + */ +void gensvm_read_data_libsvm(struct GenData *data, char *data_file) +{ + bool do_sparse, zero_based = false; + long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels, + min_index = 1; + int n_big, n_small, big_start; + double value; + FILE *fid = NULL; + char *label = NULL, + *endptr = NULL, + **big_parts = NULL, + **small_parts = NULL; + char buf[GENSVM_MAX_LINE_LENGTH]; + + fid = fopen(data_file, "r"); + if (fid == NULL) { + // LCOV_EXCL_START + err("[GenSVM Error]: Datafile %s could not be opened.\n", + data_file); + exit(EXIT_FAILURE); + // LCOV_EXCL_STOP + } + + // first count the number of elements + n = 0; + m = -1; + + num_labels = 0; + nnz = 0; + + while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) { + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + // record if this line has a label (first part has no colon) + num_labels += (!str_contains_char(big_parts[0], ':')); + + // check for each part if it is a index:value pair + for (i=0; i<n_big; i++) { + if (!str_contains_char(big_parts[i], ':')) + continue; + + // split the index:value pair + small_parts = str_split(big_parts[i], ":", &n_small); + + // convert the index to a number + index = strtol(small_parts[0], &endptr, 10); + + // catch conversion errors + if (endptr == small_parts[0] || errno != 0 || + *endptr != '\0') + exit_input_error(n+1); + + // update the maximum index + m = maximum(m, index); + + // update the minimum index + min_index = minimum(min_index, index); + + // free the small parts + for (j=0; j<n_small; j++) free(small_parts[j]); + free(small_parts); + + // increment the nonzero counter + nnz++; + } + + // free the big parts + for (i=0; i<n_big; i++) { + free(big_parts[i]); + } + free(big_parts); + + // increment the number of observations + n++; + } + + // rewind the file pointer + rewind(fid); + + // check if we have enough labels + if (num_labels > 0 && num_labels != n) { + err("[GenSVM Error]: There are some lines with missing " + "labels. Please fix this before " + "continuing.\n"); + exit(EXIT_FAILURE); + } + + // don't forget the column of ones + nnz += n; + + // deal with 0-based or 1-based indexing in the LibSVM file + if (min_index == 0) { + m++; + zero_based = true; + } + + // check if sparsity is worth it + do_sparse = gensvm_nnz_comparison(nnz, n, m+1); + if (do_sparse) { + data->spZ = gensvm_init_sparse(); + data->spZ->nnz = nnz; + data->spZ->n_row = n; + data->spZ->n_col = m+1; + data->spZ->values = Calloc(double, nnz); + data->spZ->ia = Calloc(long, n+1); + data->spZ->ja = Calloc(long, nnz); + data->spZ->ia[0] = 0; + } else { + data->RAW = Calloc(double, n*(m+1)); + data->Z = data->RAW; + } + if (num_labels > 0) + data->y = Calloc(long, n); + + K = 0; + cnt = 0; + for (i=0; i<n; i++) { + fgets(buf, GENSVM_MAX_LINE_LENGTH, fid); + + // split the string in labels and/or index:value pairs + big_parts = str_split(buf, " \t", &n_big); + + big_start = 0; + // get the label from the first part if it exists + if (!str_contains_char(big_parts[0], ':')) { + label = strtok(big_parts[0], " \t\n"); + if (label == NULL) // empty line + exit_input_error(i+1); + + // convert the label part to a number exit if there + // are errors + tmp = strtol(label, &endptr, 10); + if (endptr == label || *endptr != '\0') + exit_input_error(i+1); + + // assign label to y + data->y[i] = tmp; + + // keep track of maximum K + K = maximum(K, data->y[i]); + + // increment big part index + big_start++; + } + + row_cnt = 0; + // set the first element in the row to 1 + if (do_sparse) { + data->spZ->values[cnt] = 1.0; + data->spZ->ja[cnt] = 0; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, 0, 1.0); + } + + // read the rest of the line + for (j=big_start; j<n_big; j++) { + if (!str_contains_char(big_parts[j], ':')) + continue; + + // split the index:value pair + small_parts = str_split(big_parts[j], ":", &n_small); + if (n_small != 2) + exit_input_error(n+1); + + // convert the index to a long + errno = 0; + index = strtol(small_parts[0], &endptr, 10); + + // catch conversion errors + if (endptr == small_parts[0] || errno != 0 || + *endptr != '\0') + exit_input_error(n+1); + + // convert the value to a double + errno = 0; + value = strtod(small_parts[1], &endptr); + if (endptr == small_parts[1] || errno != 0 || + (*endptr != '\0' && !isspace(*endptr))) + exit_input_error(n+1); + + if (do_sparse) { + data->spZ->values[cnt] = value; + data->spZ->ja[cnt] = index + zero_based; + cnt++; + row_cnt++; + } else { + matrix_set(data->RAW, m+1, i, + index + zero_based, value); + } + + // free the small parts + free(small_parts[0]); + free(small_parts[1]); + free(small_parts); + } + + if (do_sparse) { + data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt; + } + + // free the big parts + for (j=0; j<n_big; j++) { + free(big_parts[j]); + } + free(big_parts); + } + + fclose(fid); + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + +} /** * @brief Read model from file diff --git a/src/gensvm_sparse.c b/src/gensvm_sparse.c index ce99b3b..1a9317e 100644 --- a/src/gensvm_sparse.c +++ b/src/gensvm_sparse.c @@ -91,6 +91,23 @@ long gensvm_count_nnz(double *A, long rows, long cols) } /** + * @brief Compare the number of nonzeros is such that sparsity if worth it + * + * @details + * This is a utility function, see gensvm_could_sparse() for more info. + * + * @param[in] nnz number of nonzero elements + * @param[in] rows number of rows + * @param[in] cols number of columns + * + * @return whether or not sparsity is worth it + */ +bool gensvm_nnz_comparison(long nnz, long rows, long cols) +{ + return (nnz < (rows*(cols-1.0)-1.0)/2.0); +} + +/** * @brief Check if it is worthwile to convert to a sparse matrix * * @details @@ -100,20 +117,19 @@ long gensvm_count_nnz(double *A, long rows, long cols) * the amount of nonzero entries is small enough, the function returns the * number of nonzeros. If it is too big, it returns -1. * + * @sa + * gensvm_nnz_comparison() + * * @param[in] A matrix in dense format (RowMajor order) * @param[in] rows number of rows of A * @param[in] cols number of columns of A * - * @return + * @return whether or not sparsity is worth it */ bool gensvm_could_sparse(double *A, long rows, long cols) { long nnz = gensvm_count_nnz(A, rows, cols); - - if (nnz < (rows*(cols-1.0)-1.0)/2.0) { - return true; - } - return false; + return gensvm_nnz_comparison(nnz, rows, cols); } diff --git a/src/gensvm_strutil.c b/src/gensvm_strutil.c index f6bb8f7..7bca2ca 100644 --- a/src/gensvm_strutil.c +++ b/src/gensvm_strutil.c @@ -63,6 +63,118 @@ bool str_endswith(const char *str, const char *suf) } /** + * @brief Check if a string contains a char + * + * @details + * Simple utility function to check if a char occurs in a string. + * + * @param[in] str input string + * @param[in] c character + * + * @return number of times c occurs in str + */ +bool str_contains_char(const char *str, const char c) +{ + size_t i, len = strlen(str); + for (i=0; i<len; i++) + if (str[i] == c) + return true; + return false; +} + +/** + * @brief Count the number of times a string contains any character of another + * + * @details + * This function is used to count the number of expected parts in the function + * str_split(). It counts the number of times a character from a string of + * characters is present in an input string. + * + * @param[in] str input string + * @param[in] chars characters to count + * + * @return number of times any character from chars occurs in str + * + */ +int count_str_occurrences(const char *str, const char *chars) +{ + size_t i, j, len_str = strlen(str), + len_chars = strlen(chars); + int count = 0; + for (i=0; i<len_str; i++) { + for (j=0; j<len_chars; j++) { + count += (str[i] == chars[j]); + } + } + return count; +} + +/** + * @brief Split a string on delimiters and return an array of parts + * + * @details + * This function takes as input a string and a string of delimiters. As + * output, it gives an array of the parts of the first string, splitted on the + * characters in the second string. The input string is not changed, and the + * output contains all copies of the input string parts. + * + * @note + * The code is based on: http://stackoverflow.com/a/9210560 + * + * @param[in] original string you wish to split + * @param[in] delims string with delimiters to split on + * @param[out] len_ret length of the output array + * + * @return array of string parts + */ +char **str_split(char *original, const char *delims, int *len_ret) +{ + char *copy = NULL, + *token = NULL, + **result = NULL; + int i, count; + + size_t len = strlen(original); + size_t n_delim = strlen(delims); + + // add the null terminator to the delimiters + char all_delim[1 + n_delim]; + for (i=0; i<n_delim; i++) + all_delim[i] = delims[i]; + all_delim[n_delim] = '\0'; + + // number of occurances of the delimiters + count = count_str_occurrences(original, delims); + + // extra count in case there is a delimiter at the end + count += (str_contains_char(delims, original[len - 1])); + + // extra count for the null terminator + count++; + + // allocate the result array + result = Calloc(char *, count); + + // tokenize a copy of the original string and keep the splits + i = 0; + copy = Calloc(char, len + 1); + strcpy(copy, original); + token = strtok(copy, all_delim); + while (token) { + result[i] = Calloc(char, strlen(token) + 1); + strcpy(result[i], token); + i++; + + token = strtok(NULL, all_delim); + } + free(copy); + + *len_ret = i; + + return result; +} + +/** * @brief Move to next line in file * * @param[in] fid File opened for reading diff --git a/tests/data/test_file_read_data_libsvm.txt b/tests/data/test_file_read_data_libsvm.txt new file mode 100644 index 0000000..da1c0ff --- /dev/null +++ b/tests/data/test_file_read_data_libsvm.txt @@ -0,0 +1,5 @@ +2 1:0.7065937536993949 2:0.7016517970438980 3:0.1548611397288129 +1 1:0.4604987687863951 2:0.6374142980176117 3:0.0370930278245423 +3 1:0.3798777132278375 2:0.5745070018747664 3:0.2570906697837264 +4 1:0.2789376050039792 2:0.4853242744610165 3:0.1894010436762711 +3 1:0.7630904372339489 2:0.1341546320318005 3:0.6827430912944857 diff --git a/tests/data/test_file_read_data_libsvm_0.txt b/tests/data/test_file_read_data_libsvm_0.txt new file mode 100644 index 0000000..6169a0c --- /dev/null +++ b/tests/data/test_file_read_data_libsvm_0.txt @@ -0,0 +1,5 @@ +2 0:0.7065937536993949 1:0.7016517970438980 2:0.1548611397288129 +1 0:0.4604987687863951 1:0.6374142980176117 2:0.0370930278245423 +3 0:0.3798777132278375 1:0.5745070018747664 2:0.2570906697837264 +4 0:0.2789376050039792 1:0.4853242744610165 2:0.1894010436762711 +3 0:0.7630904372339489 1:0.1341546320318005 2:0.6827430912944857 diff --git a/tests/data/test_file_read_data_no_label_libsvm.txt b/tests/data/test_file_read_data_no_label_libsvm.txt new file mode 100644 index 0000000..5452216 --- /dev/null +++ b/tests/data/test_file_read_data_no_label_libsvm.txt @@ -0,0 +1,5 @@ +1:0.7065937536993949 2:0.7016517970438980 3:0.1548611397288129 +1:0.4604987687863951 2:0.6374142980176117 3:0.0370930278245423 +1:0.3798777132278375 2:0.5745070018747664 3:0.2570906697837264 +1:0.2789376050039792 2:0.4853242744610165 3:0.1894010436762711 +1:0.7630904372339489 2:0.1341546320318005 3:0.6827430912944857 diff --git a/tests/data/test_file_read_data_sparse_libsvm.txt b/tests/data/test_file_read_data_sparse_libsvm.txt new file mode 100644 index 0000000..ff41fe7 --- /dev/null +++ b/tests/data/test_file_read_data_sparse_libsvm.txt @@ -0,0 +1,10 @@ +2 2:0.7016517970438980 +1 3:0.0370930278245423 +3 +4 2:0.4853242744610165 +3 1:0.7630904372339489 +1 +2 +3 +4 +1 diff --git a/tests/src/test_gensvm_io.c b/tests/src/test_gensvm_io.c index f1a7d22..29776d6 100644 --- a/tests/src/test_gensvm_io.c +++ b/tests/src/test_gensvm_io.c @@ -265,6 +265,321 @@ char *test_gensvm_read_data_no_label() return NULL; } +char *test_gensvm_read_data_libsvm() +{ + char *filename = "./data/test_file_read_data_libsvm.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 5, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 4, "Incorrect value for K"); + + // check if all data is read correctly. + mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949, + "Incorrect Z value at 0, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980, + "Incorrect Z value at 0, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129, + "Incorrect Z value at 0, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951, + "Incorrect Z value at 1, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117, + "Incorrect Z value at 1, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423, + "Incorrect Z value at 1, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375, + "Incorrect Z value at 2, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664, + "Incorrect Z value at 2, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264, + "Incorrect Z value at 2, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792, + "Incorrect Z value at 3, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165, + "Incorrect Z value at 3, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711, + "Incorrect Z value at 3, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489, + "Incorrect Z value at 4, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005, + "Incorrect Z value at 4, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857, + "Incorrect Z value at 4, 3"); + // check if RAW = Z + mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer"); + + // check if labels read correctly + mu_assert(data->y[0] == 2, "Incorrect label read at 0"); + mu_assert(data->y[1] == 1, "Incorrect label read at 1"); + mu_assert(data->y[2] == 3, "Incorrect label read at 2"); + mu_assert(data->y[3] == 4, "Incorrect label read at 3"); + mu_assert(data->y[4] == 3, "Incorrect label read at 4"); + + // check if the column of ones is added + mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1, + "Incorrect Z value at 0, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1, + "Incorrect Z value at 1, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1, + "Incorrect Z value at 2, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1, + "Incorrect Z value at 3, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1, + "Incorrect Z value at 4, 0"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_gensvm_read_data_libsvm_0based() +{ + char *filename = "./data/test_file_read_data_libsvm_0.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 5, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 4, "Incorrect value for K"); + + // check if all data is read correctly. + mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949, + "Incorrect Z value at 0, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980, + "Incorrect Z value at 0, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129, + "Incorrect Z value at 0, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951, + "Incorrect Z value at 1, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117, + "Incorrect Z value at 1, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423, + "Incorrect Z value at 1, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375, + "Incorrect Z value at 2, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664, + "Incorrect Z value at 2, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264, + "Incorrect Z value at 2, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792, + "Incorrect Z value at 3, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165, + "Incorrect Z value at 3, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711, + "Incorrect Z value at 3, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489, + "Incorrect Z value at 4, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005, + "Incorrect Z value at 4, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857, + "Incorrect Z value at 4, 3"); + // check if RAW = Z + mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer"); + + // check if labels read correctly + mu_assert(data->y[0] == 2, "Incorrect label read at 0"); + mu_assert(data->y[1] == 1, "Incorrect label read at 1"); + mu_assert(data->y[2] == 3, "Incorrect label read at 2"); + mu_assert(data->y[3] == 4, "Incorrect label read at 3"); + mu_assert(data->y[4] == 3, "Incorrect label read at 4"); + + // check if the column of ones is added + mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1, + "Incorrect Z value at 0, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1, + "Incorrect Z value at 1, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1, + "Incorrect Z value at 2, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1, + "Incorrect Z value at 3, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1, + "Incorrect Z value at 4, 0"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_gensvm_read_data_libsvm_sparse() +{ + char *filename = "./data/test_file_read_data_sparse_libsvm.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 10, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 4, "Incorrect value for K"); + + // check if dense data pointers are NULL + mu_assert(data->Z == NULL, "Z pointer isn't NULL"); + mu_assert(data->RAW == NULL, "RAW pointer isn't NULL"); + + // check sparse data structure + mu_assert(data->spZ != NULL, "spZ is NULL"); + mu_assert(data->spZ->nnz == 14, "Incorrect nnz"); + mu_assert(data->spZ->n_row == 10, "Incorrect n_row"); + mu_assert(data->spZ->n_col == 4, "Incorrect n_col"); + + // check sparse values + mu_assert(data->spZ->values[0] == 1.0, + "Incorrect nonzero value at 0"); + mu_assert(data->spZ->values[1] == 0.7016517970438980, + "Incorrect nonzero value at 1"); + mu_assert(data->spZ->values[2] == 1.0, + "Incorrect nonzero value at 2"); + mu_assert(data->spZ->values[3] == 0.0370930278245423, + "Incorrect nonzero value at 3"); + mu_assert(data->spZ->values[4] == 1.0, + "Incorrect nonzero value at 4"); + mu_assert(data->spZ->values[5] == 1.0, + "Incorrect nonzero value at 5"); + mu_assert(data->spZ->values[6] == 0.4853242744610165, + "Incorrect nonzero value at 6"); + mu_assert(data->spZ->values[7] == 1.0, + "Incorrect nonzero value at 7"); + mu_assert(data->spZ->values[8] == 0.7630904372339489, + "Incorrect nonzero value at 8"); + mu_assert(data->spZ->values[9] == 1.0, + "Incorrect nonzero value at 9"); + mu_assert(data->spZ->values[10] == 1.0, + "Incorrect nonzero value at 10"); + mu_assert(data->spZ->values[11] == 1.0, + "Incorrect nonzero value at 11"); + mu_assert(data->spZ->values[12] == 1.0, + "Incorrect nonzero value at 12"); + mu_assert(data->spZ->values[13] == 1.0, + "Incorrect nonzero value at 13"); + + // check sparse row lengths + mu_assert(data->spZ->ia[0] == 0, "Incorrect ia value at 0"); + mu_assert(data->spZ->ia[1] == 2, "Incorrect ia value at 1"); + mu_assert(data->spZ->ia[2] == 4, "Incorrect ia value at 2"); + mu_assert(data->spZ->ia[3] == 5, "Incorrect ia value at 3"); + mu_assert(data->spZ->ia[4] == 7, "Incorrect ia value at 4"); + mu_assert(data->spZ->ia[5] == 9, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[6] == 10, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[7] == 11, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[8] == 12, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[9] == 13, "Incorrect ia value at 5"); + mu_assert(data->spZ->ia[10] == 14, "Incorrect ia value at 5"); + + // check sparse column indices + mu_assert(data->spZ->ja[0] == 0, "Incorrect ja value at 0"); + mu_assert(data->spZ->ja[1] == 2, "Incorrect ja value at 1"); + mu_assert(data->spZ->ja[2] == 0, "Incorrect ja value at 2"); + mu_assert(data->spZ->ja[3] == 3, "Incorrect ja value at 3"); + mu_assert(data->spZ->ja[4] == 0, "Incorrect ja value at 4"); + mu_assert(data->spZ->ja[5] == 0, "Incorrect ja value at 5"); + mu_assert(data->spZ->ja[6] == 2, "Incorrect ja value at 6"); + mu_assert(data->spZ->ja[7] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[8] == 1, "Incorrect ja value at 8"); + mu_assert(data->spZ->ja[9] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[10] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[11] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[12] == 0, "Incorrect ja value at 7"); + mu_assert(data->spZ->ja[13] == 0, "Incorrect ja value at 7"); + + // check if labels read correctly + mu_assert(data->y[0] == 2, "Incorrect label read at 0"); + mu_assert(data->y[1] == 1, "Incorrect label read at 1"); + mu_assert(data->y[2] == 3, "Incorrect label read at 2"); + mu_assert(data->y[3] == 4, "Incorrect label read at 3"); + mu_assert(data->y[4] == 3, "Incorrect label read at 4"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_gensvm_read_data_libsvm_no_label() +{ + char *filename = "./data/test_file_read_data_no_label_libsvm.txt"; + struct GenData *data = gensvm_init_data(); + + // start test code // + gensvm_read_data_libsvm(data, filename); + + // check if dimensions are correctly read + mu_assert(data->n == 5, "Incorrect value for n"); + mu_assert(data->m == 3, "Incorrect value for m"); + mu_assert(data->r == 3, "Incorrect value for r"); + mu_assert(data->K == 0, "Incorrect value for K"); + + // check if all data is read correctly. + mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949, + "Incorrect Z value at 0, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980, + "Incorrect Z value at 0, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129, + "Incorrect Z value at 0, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951, + "Incorrect Z value at 1, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117, + "Incorrect Z value at 1, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423, + "Incorrect Z value at 1, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375, + "Incorrect Z value at 2, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664, + "Incorrect Z value at 2, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264, + "Incorrect Z value at 2, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792, + "Incorrect Z value at 3, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165, + "Incorrect Z value at 3, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711, + "Incorrect Z value at 3, 3"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489, + "Incorrect Z value at 4, 1"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005, + "Incorrect Z value at 4, 2"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857, + "Incorrect Z value at 4, 3"); + // check if RAW = Z + mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer"); + + // check if labels read correctly + mu_assert(data->y == NULL, "Outcome pointer is not NULL"); + + // check if the column of ones is added + mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1, + "Incorrect Z value at 0, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1, + "Incorrect Z value at 1, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1, + "Incorrect Z value at 2, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1, + "Incorrect Z value at 3, 0"); + mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1, + "Incorrect Z value at 4, 0"); + + // end test code // + + gensvm_free_data(data); + + return NULL; +} + char *test_gensvm_read_model() { struct GenModel *model = gensvm_init_model(); @@ -536,6 +851,11 @@ char *all_tests() mu_run_test(test_gensvm_read_data); mu_run_test(test_gensvm_read_data_sparse); mu_run_test(test_gensvm_read_data_no_label); + mu_run_test(test_gensvm_read_data_libsvm); + mu_run_test(test_gensvm_read_data_libsvm_0based); + mu_run_test(test_gensvm_read_data_libsvm_sparse); + mu_run_test(test_gensvm_read_data_libsvm_no_label); + mu_run_test(test_gensvm_read_model); mu_run_test(test_gensvm_write_model); mu_run_test(test_gensvm_write_predictions); |
