aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/GenSVMgrid.c15
-rw-r--r--src/GenSVMtraintest.c25
-rw-r--r--src/gensvm_io.c263
-rw-r--r--src/gensvm_sparse.c28
-rw-r--r--src/gensvm_strutil.c112
5 files changed, 428 insertions, 15 deletions
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c
index 3799de3..41d1070 100644
--- a/src/GenSVMgrid.c
+++ b/src/GenSVMgrid.c
@@ -72,7 +72,8 @@ void exit_with_help(char **argv)
printf("Usage: %s [options] grid_file\n", argv[0]);
printf("Options:\n");
printf("-h | -help : print this help.\n");
- printf("-q : quiet mode (no output, not even errors!)\n");
+ printf("-q : quiet mode (no output, not even errors!)\n");
+ printf("-x : data files are in LibSVM/SVMlight format\n");
exit(EXIT_FAILURE);
}
@@ -97,6 +98,7 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ bool libsvm_format = false;
char input_filename[GENSVM_MAX_LINE_LENGTH];
struct GenGrid *grid = gensvm_init_grid();
@@ -108,12 +110,16 @@ int main(int argc, char **argv)
|| gensvm_check_argv_eq(argc, argv, "-h") )
exit_with_help(argv);
parse_command_line(argc, argv, input_filename);
+ libsvm_format = gensvm_check_argv(argc, argv, "-x");
note("Reading grid file\n");
read_grid_from_file(input_filename, grid);
note("Reading data from %s\n", grid->train_data_file);
- gensvm_read_data(train_data, grid->train_data_file);
+ if (libsvm_format)
+ gensvm_read_data_libsvm(train_data, grid->train_data_file);
+ else
+ gensvm_read_data(train_data, grid->train_data_file);
// check labels of training data
gensvm_check_outcome_contiguous(train_data);
@@ -196,6 +202,9 @@ void parse_command_line(int argc, char **argv, char *input_filename)
GENSVM_ERROR_FILE = NULL;
i--;
break;
+ case 'x':
+ i--;
+ break;
default:
fprintf(stderr, "Unknown option: -%c\n",
argv[i-1][1]);
@@ -267,7 +276,7 @@ void read_grid_from_file(char *input_filename, struct GenGrid *grid)
if (fid == NULL) {
fprintf(stderr, "Error opening grid file %s\n",
input_filename);
- exit(1);
+ exit(EXIT_FAILURE);
}
grid->traintype = CV;
while ( fgets(buffer, GENSVM_MAX_LINE_LENGTH, fid) != NULL ) {
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c
index 8bebf3a..a275e57 100644
--- a/src/GenSVMtraintest.c
+++ b/src/GenSVMtraintest.c
@@ -77,7 +77,6 @@ void exit_with_help(char **argv)
"Huber hinge\n");
printf("-l lambda : set the value of lambda "
"(lambda > 0)\n");
- printf("-s seed_model_file : use previous model as seed for V\n");
printf("-m model_output_file : write model output to file "
"(not saved if no file provided)\n");
printf("-o prediction_output : write predictions of test data to "
@@ -88,8 +87,11 @@ void exit_with_help(char **argv)
"errors!)\n");
printf("-r rho : choose the weigth specification "
"(1 = unit, 2 = group)\n");
+ printf("-s seed_model_file : use previous model as seed for V\n");
printf("-t type : kerneltype (0=LINEAR, 1=POLY, 2=RBF, "
"3=SIGMOID)\n");
+ printf("-x : data files are in LibSVM/SVMlight "
+ "format\n");
printf("\n");
exit(EXIT_FAILURE);
@@ -108,6 +110,7 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ bool libsvm_format = false;
long i, *predy = NULL;
double performance;
@@ -130,9 +133,15 @@ int main(int argc, char **argv)
parse_command_line(argc, argv, model, &model_inputfile,
&training_inputfile, &testing_inputfile,
&model_outputfile, &prediction_outputfile);
+ libsvm_format = gensvm_check_argv(argc, argv, "-x");
+
+ // read data from file
+ if (libsvm_format)
+ gensvm_read_data_libsvm(traindata, training_inputfile);
+ else
+ gensvm_read_data(traindata, training_inputfile);
- // read data from file and check labels
- gensvm_read_data(traindata, training_inputfile);
+ // check labels for consistency
if (!gensvm_check_outcome_contiguous(traindata)) {
err("[GenSVM Error]: Class labels should start from 1 and "
"have no gaps. Please reformat your data.\n");
@@ -168,7 +177,11 @@ int main(int argc, char **argv)
// if we also have a test set, predict labels and write to predictions
// to an output file if specified
if (testing_inputfile != NULL) {
- gensvm_read_data(testdata, testing_inputfile);
+ // read the test data
+ if (libsvm_format)
+ gensvm_read_data_libsvm(testdata, testing_inputfile);
+ else
+ gensvm_read_data(testdata, testing_inputfile);
// check if we are sparse and want nonlinearity
if (testdata->Z == NULL && model->kerneltype != K_LINEAR) {
@@ -258,6 +271,7 @@ void parse_command_line(int argc, char **argv, struct GenModel *model,
GENSVM_ERROR_FILE = stderr;
// parse options
+ // note: flags that don't have an argument should decrement i
for (i=1; i<argc; i++) {
if (argv[i][0] != '-') break;
if (++i>=argc) {
@@ -311,6 +325,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model,
GENSVM_ERROR_FILE = NULL;
i--;
break;
+ case 'x':
+ i--;
+ break;
default:
// this one should always print explicitly to
// stderr, even if '-q' is supplied, because
diff --git a/src/gensvm_io.c b/src/gensvm_io.c
index 0d14144..d41b2ef 100644
--- a/src/gensvm_io.c
+++ b/src/gensvm_io.c
@@ -6,7 +6,7 @@
*
* @details
* This file contains functions for reading and writing model files, and data
- * files. It also contains a function for generating a string of the current
+ * files. It also contains a function for generating a string of the current
* time, used in writing output files.
*
* @copyright
@@ -38,7 +38,7 @@
* Read the data from the data_file. The data matrix X is augmented
* with a column of ones, to get the matrix Z. The data is expected
* to follow a specific format, which is specified in the @ref spec_data_file.
- * The class labels are assumed to be in the interval [1 .. K], which can be
+ * The class labels are assumed to be in the interval [1 .. K], which can be
* checked using the function gensvm_check_outcome_contiguous().
*
* @param[in,out] dataset initialized GenData struct
@@ -133,6 +133,265 @@ void gensvm_read_data(struct GenData *dataset, char *data_file)
}
}
+/**
+ * @brief Print an error to the screen and exit (copied from LibSVM)
+ *
+ * @param[in] line_num line number where the error occured
+ *
+ */
+void exit_input_error(int line_num)
+{
+ err("[GenSVM Error]: Wrong input format on line: %i\n", line_num);
+ exit(EXIT_FAILURE);
+}
+
+/**
+ * @brief Read data from a file in LibSVM/SVMlight format
+ *
+ * @details
+ * This function reads data from a file where the data is stored in
+ * LibSVM/SVMlight format. This is a sparse data format, which can be
+ * beneficial for certain applications. The advantage of having this function
+ * here is twofold: 1) existing datasets where data is stored in
+ * LibSVM/SVMlight format can be easily used in GenSVM, and 2) sparse datasets
+ * which are too large for memory when kept in dense format can be loaded
+ * efficiently into GenSVM.
+ *
+ * @note
+ * This code is based on the read_problem() function in the svm-train.c
+ * file of LibSVM. It has however been expanded to be able to handle data
+ * files without labels.
+ *
+ * @note
+ * This file tries to detect whether 1-based or 0-based indexing is used in
+ * the data file. By default 1-based indexing is used, but if an index is
+ * found with value 0, 0-based indexing is assumed.
+ *
+ * @sa
+ * gensvm_read_problem()
+ *
+ * @param[in] data GenData structure
+ * @param[in] data_file filename of the datafile
+ *
+ */
+void gensvm_read_data_libsvm(struct GenData *data, char *data_file)
+{
+ bool do_sparse, zero_based = false;
+ long i, j, n, m, K, nnz, cnt, tmp, index, row_cnt, num_labels,
+ min_index = 1;
+ int n_big, n_small, big_start;
+ double value;
+ FILE *fid = NULL;
+ char *label = NULL,
+ *endptr = NULL,
+ **big_parts = NULL,
+ **small_parts = NULL;
+ char buf[GENSVM_MAX_LINE_LENGTH];
+
+ fid = fopen(data_file, "r");
+ if (fid == NULL) {
+ // LCOV_EXCL_START
+ err("[GenSVM Error]: Datafile %s could not be opened.\n",
+ data_file);
+ exit(EXIT_FAILURE);
+ // LCOV_EXCL_STOP
+ }
+
+ // first count the number of elements
+ n = 0;
+ m = -1;
+
+ num_labels = 0;
+ nnz = 0;
+
+ while (fgets(buf, GENSVM_MAX_LINE_LENGTH, fid) != NULL) {
+ // split the string in labels and/or index:value pairs
+ big_parts = str_split(buf, " \t", &n_big);
+
+ // record if this line has a label (first part has no colon)
+ num_labels += (!str_contains_char(big_parts[0], ':'));
+
+ // check for each part if it is a index:value pair
+ for (i=0; i<n_big; i++) {
+ if (!str_contains_char(big_parts[i], ':'))
+ continue;
+
+ // split the index:value pair
+ small_parts = str_split(big_parts[i], ":", &n_small);
+
+ // convert the index to a number
+ index = strtol(small_parts[0], &endptr, 10);
+
+ // catch conversion errors
+ if (endptr == small_parts[0] || errno != 0 ||
+ *endptr != '\0')
+ exit_input_error(n+1);
+
+ // update the maximum index
+ m = maximum(m, index);
+
+ // update the minimum index
+ min_index = minimum(min_index, index);
+
+ // free the small parts
+ for (j=0; j<n_small; j++) free(small_parts[j]);
+ free(small_parts);
+
+ // increment the nonzero counter
+ nnz++;
+ }
+
+ // free the big parts
+ for (i=0; i<n_big; i++) {
+ free(big_parts[i]);
+ }
+ free(big_parts);
+
+ // increment the number of observations
+ n++;
+ }
+
+ // rewind the file pointer
+ rewind(fid);
+
+ // check if we have enough labels
+ if (num_labels > 0 && num_labels != n) {
+ err("[GenSVM Error]: There are some lines with missing "
+ "labels. Please fix this before "
+ "continuing.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ // don't forget the column of ones
+ nnz += n;
+
+ // deal with 0-based or 1-based indexing in the LibSVM file
+ if (min_index == 0) {
+ m++;
+ zero_based = true;
+ }
+
+ // check if sparsity is worth it
+ do_sparse = gensvm_nnz_comparison(nnz, n, m+1);
+ if (do_sparse) {
+ data->spZ = gensvm_init_sparse();
+ data->spZ->nnz = nnz;
+ data->spZ->n_row = n;
+ data->spZ->n_col = m+1;
+ data->spZ->values = Calloc(double, nnz);
+ data->spZ->ia = Calloc(long, n+1);
+ data->spZ->ja = Calloc(long, nnz);
+ data->spZ->ia[0] = 0;
+ } else {
+ data->RAW = Calloc(double, n*(m+1));
+ data->Z = data->RAW;
+ }
+ if (num_labels > 0)
+ data->y = Calloc(long, n);
+
+ K = 0;
+ cnt = 0;
+ for (i=0; i<n; i++) {
+ fgets(buf, GENSVM_MAX_LINE_LENGTH, fid);
+
+ // split the string in labels and/or index:value pairs
+ big_parts = str_split(buf, " \t", &n_big);
+
+ big_start = 0;
+ // get the label from the first part if it exists
+ if (!str_contains_char(big_parts[0], ':')) {
+ label = strtok(big_parts[0], " \t\n");
+ if (label == NULL) // empty line
+ exit_input_error(i+1);
+
+ // convert the label part to a number exit if there
+ // are errors
+ tmp = strtol(label, &endptr, 10);
+ if (endptr == label || *endptr != '\0')
+ exit_input_error(i+1);
+
+ // assign label to y
+ data->y[i] = tmp;
+
+ // keep track of maximum K
+ K = maximum(K, data->y[i]);
+
+ // increment big part index
+ big_start++;
+ }
+
+ row_cnt = 0;
+ // set the first element in the row to 1
+ if (do_sparse) {
+ data->spZ->values[cnt] = 1.0;
+ data->spZ->ja[cnt] = 0;
+ cnt++;
+ row_cnt++;
+ } else {
+ matrix_set(data->RAW, m+1, i, 0, 1.0);
+ }
+
+ // read the rest of the line
+ for (j=big_start; j<n_big; j++) {
+ if (!str_contains_char(big_parts[j], ':'))
+ continue;
+
+ // split the index:value pair
+ small_parts = str_split(big_parts[j], ":", &n_small);
+ if (n_small != 2)
+ exit_input_error(n+1);
+
+ // convert the index to a long
+ errno = 0;
+ index = strtol(small_parts[0], &endptr, 10);
+
+ // catch conversion errors
+ if (endptr == small_parts[0] || errno != 0 ||
+ *endptr != '\0')
+ exit_input_error(n+1);
+
+ // convert the value to a double
+ errno = 0;
+ value = strtod(small_parts[1], &endptr);
+ if (endptr == small_parts[1] || errno != 0 ||
+ (*endptr != '\0' && !isspace(*endptr)))
+ exit_input_error(n+1);
+
+ if (do_sparse) {
+ data->spZ->values[cnt] = value;
+ data->spZ->ja[cnt] = index + zero_based;
+ cnt++;
+ row_cnt++;
+ } else {
+ matrix_set(data->RAW, m+1, i,
+ index + zero_based, value);
+ }
+
+ // free the small parts
+ free(small_parts[0]);
+ free(small_parts[1]);
+ free(small_parts);
+ }
+
+ if (do_sparse) {
+ data->spZ->ia[i+1] = data->spZ->ia[i] + row_cnt;
+ }
+
+ // free the big parts
+ for (j=0; j<n_big; j++) {
+ free(big_parts[j]);
+ }
+ free(big_parts);
+ }
+
+ fclose(fid);
+
+ data->n = n;
+ data->m = m;
+ data->r = m;
+ data->K = K;
+
+}
/**
* @brief Read model from file
diff --git a/src/gensvm_sparse.c b/src/gensvm_sparse.c
index ce99b3b..1a9317e 100644
--- a/src/gensvm_sparse.c
+++ b/src/gensvm_sparse.c
@@ -91,6 +91,23 @@ long gensvm_count_nnz(double *A, long rows, long cols)
}
/**
+ * @brief Compare the number of nonzeros is such that sparsity if worth it
+ *
+ * @details
+ * This is a utility function, see gensvm_could_sparse() for more info.
+ *
+ * @param[in] nnz number of nonzero elements
+ * @param[in] rows number of rows
+ * @param[in] cols number of columns
+ *
+ * @return whether or not sparsity is worth it
+ */
+bool gensvm_nnz_comparison(long nnz, long rows, long cols)
+{
+ return (nnz < (rows*(cols-1.0)-1.0)/2.0);
+}
+
+/**
* @brief Check if it is worthwile to convert to a sparse matrix
*
* @details
@@ -100,20 +117,19 @@ long gensvm_count_nnz(double *A, long rows, long cols)
* the amount of nonzero entries is small enough, the function returns the
* number of nonzeros. If it is too big, it returns -1.
*
+ * @sa
+ * gensvm_nnz_comparison()
+ *
* @param[in] A matrix in dense format (RowMajor order)
* @param[in] rows number of rows of A
* @param[in] cols number of columns of A
*
- * @return
+ * @return whether or not sparsity is worth it
*/
bool gensvm_could_sparse(double *A, long rows, long cols)
{
long nnz = gensvm_count_nnz(A, rows, cols);
-
- if (nnz < (rows*(cols-1.0)-1.0)/2.0) {
- return true;
- }
- return false;
+ return gensvm_nnz_comparison(nnz, rows, cols);
}
diff --git a/src/gensvm_strutil.c b/src/gensvm_strutil.c
index f6bb8f7..7bca2ca 100644
--- a/src/gensvm_strutil.c
+++ b/src/gensvm_strutil.c
@@ -63,6 +63,118 @@ bool str_endswith(const char *str, const char *suf)
}
/**
+ * @brief Check if a string contains a char
+ *
+ * @details
+ * Simple utility function to check if a char occurs in a string.
+ *
+ * @param[in] str input string
+ * @param[in] c character
+ *
+ * @return number of times c occurs in str
+ */
+bool str_contains_char(const char *str, const char c)
+{
+ size_t i, len = strlen(str);
+ for (i=0; i<len; i++)
+ if (str[i] == c)
+ return true;
+ return false;
+}
+
+/**
+ * @brief Count the number of times a string contains any character of another
+ *
+ * @details
+ * This function is used to count the number of expected parts in the function
+ * str_split(). It counts the number of times a character from a string of
+ * characters is present in an input string.
+ *
+ * @param[in] str input string
+ * @param[in] chars characters to count
+ *
+ * @return number of times any character from chars occurs in str
+ *
+ */
+int count_str_occurrences(const char *str, const char *chars)
+{
+ size_t i, j, len_str = strlen(str),
+ len_chars = strlen(chars);
+ int count = 0;
+ for (i=0; i<len_str; i++) {
+ for (j=0; j<len_chars; j++) {
+ count += (str[i] == chars[j]);
+ }
+ }
+ return count;
+}
+
+/**
+ * @brief Split a string on delimiters and return an array of parts
+ *
+ * @details
+ * This function takes as input a string and a string of delimiters. As
+ * output, it gives an array of the parts of the first string, splitted on the
+ * characters in the second string. The input string is not changed, and the
+ * output contains all copies of the input string parts.
+ *
+ * @note
+ * The code is based on: http://stackoverflow.com/a/9210560
+ *
+ * @param[in] original string you wish to split
+ * @param[in] delims string with delimiters to split on
+ * @param[out] len_ret length of the output array
+ *
+ * @return array of string parts
+ */
+char **str_split(char *original, const char *delims, int *len_ret)
+{
+ char *copy = NULL,
+ *token = NULL,
+ **result = NULL;
+ int i, count;
+
+ size_t len = strlen(original);
+ size_t n_delim = strlen(delims);
+
+ // add the null terminator to the delimiters
+ char all_delim[1 + n_delim];
+ for (i=0; i<n_delim; i++)
+ all_delim[i] = delims[i];
+ all_delim[n_delim] = '\0';
+
+ // number of occurances of the delimiters
+ count = count_str_occurrences(original, delims);
+
+ // extra count in case there is a delimiter at the end
+ count += (str_contains_char(delims, original[len - 1]));
+
+ // extra count for the null terminator
+ count++;
+
+ // allocate the result array
+ result = Calloc(char *, count);
+
+ // tokenize a copy of the original string and keep the splits
+ i = 0;
+ copy = Calloc(char, len + 1);
+ strcpy(copy, original);
+ token = strtok(copy, all_delim);
+ while (token) {
+ result[i] = Calloc(char, strlen(token) + 1);
+ strcpy(result[i], token);
+ i++;
+
+ token = strtok(NULL, all_delim);
+ }
+ free(copy);
+
+ *len_ret = i;
+
+ return result;
+}
+
+/**
* @brief Move to next line in file
*
* @param[in] fid File opened for reading