diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2016-12-07 16:39:55 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2016-12-07 16:39:55 +0100 |
| commit | 918d463103215207b9d9975bb6c0ea75754da6f9 (patch) | |
| tree | e3e466a3a281e223700d1ce2b06c3e5e0f99e935 | |
| parent | throw warning when using sparse matrices with kernels (diff) | |
| download | gensvm-918d463103215207b9d9975bb6c0ea75754da6f9.tar.gz gensvm-918d463103215207b9d9975bb6c0ea75754da6f9.zip | |
moved check for class labels to seperate module
| -rw-r--r-- | include/gensvm_checks.h | 36 | ||||
| -rw-r--r-- | include/gensvm_globals.h | 2 | ||||
| -rw-r--r-- | src/GenSVMgrid.c | 9 | ||||
| -rw-r--r-- | src/GenSVMtraintest.c | 10 | ||||
| -rw-r--r-- | src/gensvm_checks.c | 77 | ||||
| -rw-r--r-- | src/gensvm_io.c | 44 | ||||
| -rw-r--r-- | tests/src/test_gensvm_checks.c | 146 | ||||
| -rw-r--r-- | tests/src/test_gensvm_train.c | 6 |
8 files changed, 286 insertions, 44 deletions
diff --git a/include/gensvm_checks.h b/include/gensvm_checks.h new file mode 100644 index 0000000..08bad02 --- /dev/null +++ b/include/gensvm_checks.h @@ -0,0 +1,36 @@ +/** + * @file gensvm_checks.h + * @author G.J.J. van den Burg + * @date 2016-12-07 + * @brief Header file for gensvm_checks.c + * + * @copyright + Copyright 2016, G.J.J. van den Burg. + + This file is part of GenSVM. + + GenSVM is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + GenSVM is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with GenSVM. If not, see <http://www.gnu.org/licenses/>. + + */ + +#ifndef GENSVM_CHECKS_H +#define GENSVM_CHECKS_H + +// includes +#include "gensvm_base.h" + +// function declarations +bool gensvm_check_outcome_contiguous(struct GenData *data); + +#endif diff --git a/include/gensvm_globals.h b/include/gensvm_globals.h index 8ca3189..1151069 100644 --- a/include/gensvm_globals.h +++ b/include/gensvm_globals.h @@ -38,6 +38,7 @@ #include "gensvm_memory.h" +// all system libraries are included here #include <stdarg.h> #include <stdio.h> #include <stdlib.h> @@ -46,6 +47,7 @@ #include <math.h> #include <time.h> #include <cblas.h> +#include <limits.h> // ########################### Type definitions ########################### // diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 1fd7463..3799de3 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -37,6 +37,7 @@ */ +#include "gensvm_checks.h" #include "gensvm_cmdarg.h" #include "gensvm_io.h" #include "gensvm_gridsearch.h" @@ -114,6 +115,14 @@ int main(int argc, char **argv) note("Reading data from %s\n", grid->train_data_file); gensvm_read_data(train_data, grid->train_data_file); + // check labels of training data + gensvm_check_outcome_contiguous(train_data); + if (!gensvm_check_outcome_contiguous(train_data)) { + err("[GenSVM Error]: Class labels should start from 1 and " + "have no gaps. Please reformat your data.\n"); + exit(EXIT_FAILURE); + } + // check if we are sparse and want nonlinearity if (train_data->Z == NULL && grid->kerneltype != K_LINEAR) { err("[GenSVM Warning]: Sparse matrices with nonlinear kernels " diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index ecf455b..8bebf3a 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -28,6 +28,7 @@ */ +#include "gensvm_checks.h" #include "gensvm_cmdarg.h" #include "gensvm_io.h" #include "gensvm_train.h" @@ -130,8 +131,15 @@ int main(int argc, char **argv) &training_inputfile, &testing_inputfile, &model_outputfile, &prediction_outputfile); - // read data from files + // read data from file and check labels gensvm_read_data(traindata, training_inputfile); + if (!gensvm_check_outcome_contiguous(traindata)) { + err("[GenSVM Error]: Class labels should start from 1 and " + "have no gaps. Please reformat your data.\n"); + exit(EXIT_FAILURE); + } + + // save data filename to model model->data_file = Calloc(char, GENSVM_MAX_LINE_LENGTH); strcpy(model->data_file, training_inputfile); diff --git a/src/gensvm_checks.c b/src/gensvm_checks.c new file mode 100644 index 0000000..0f7c499 --- /dev/null +++ b/src/gensvm_checks.c @@ -0,0 +1,77 @@ +/** + * @file gensvm_checks.c + * @author G.J.J. van den Burg + * @date 2016-12-07 + * @brief Sanity checks used to ensure inputs are as expected + * + * @copyright + Copyright 2016, G.J.J. van den Burg. + + This file is part of GenSVM. + + GenSVM is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + GenSVM is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with GenSVM. If not, see <http://www.gnu.org/licenses/>. + + */ + +#include "gensvm_checks.h" + +/** + * @brief Check if the labels are contiguous on [1 .. K] + * + * @details + * The GenSVM library currently requires that the labels that are supplied in + * a dataset are contigous on the interval [1 .. K] and have no gaps. This is + * required because the dimensionality of the problem is directly related to + * the maximum class label K. This function checks if the labels are indeed in + * the desired range. + * + * @param[in] data a GenData struct with the current data + * + * @return whether the labels are contiguous or not + */ +bool gensvm_check_outcome_contiguous(struct GenData *data) +{ + bool in_uniq, is_contiguous = true; + long i, j, K = 1; + long max_y = -1, + min_y = LONG_MAX; + long *uniq_y = Calloc(long, K); + uniq_y[0] = data->y[0]; + + for (i=1; i<data->n; i++) { + in_uniq = false; + for (j=0; j<K; j++) { + if (uniq_y[j] == data->y[i]) { + in_uniq = true; + break; + } + } + + if (!in_uniq) { + uniq_y = Realloc(uniq_y, long, K+1); + uniq_y[K++] = data->y[i]; + } + + max_y = maximum(max_y, data->y[i]); + min_y = minimum(min_y, data->y[i]); + } + + if (min_y < 1 || max_y > K) { + is_contiguous = false; + } + + free(uniq_y); + + return is_contiguous; +} diff --git a/src/gensvm_io.c b/src/gensvm_io.c index 78838b1..0d14144 100644 --- a/src/gensvm_io.c +++ b/src/gensvm_io.c @@ -29,7 +29,6 @@ */ -#include <limits.h> #include "gensvm_io.h" /** @@ -39,12 +38,8 @@ * Read the data from the data_file. The data matrix X is augmented * with a column of ones, to get the matrix Z. The data is expected * to follow a specific format, which is specified in the @ref spec_data_file. - * The class labels are checked to make sure they correspond to the interval - * [1 .. K], where K is the total number of classes, without any gaps. - * - * @todo - * Make sure that this function allows datasets without class labels for - * testing. + * The class labels are assumed to be in the interval [1 .. K], which can be + * checked using the function gensvm_check_outcome_contiguous(). * * @param[in,out] dataset initialized GenData struct * @param[in] data_file filename of the data file. @@ -52,15 +47,10 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) { FILE *fid = NULL; - bool in_uniq; long i, j, n, m, nr = 0, - K = 0, - max_y = -1, - min_y = LONG_MAX; + K = 0; double value; - long *uniq_y = NULL; - char buf[GENSVM_MAX_LINE_LENGTH]; if ((fid = fopen(data_file, "r")) == NULL) { @@ -96,8 +86,6 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) dataset->y = Malloc(long, n); dataset->y[0] = value; K = 1; - uniq_y = Calloc(long, K); - uniq_y[0] = value; } else { free(dataset->y); dataset->y = NULL; @@ -112,33 +100,11 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) if (dataset->y != NULL) { nr += fscanf(fid, "%lf", &value); dataset->y[i] = (long) value; - - // this is to keep track of the unique values of y, so - // we can warn when they're not encoded correctly - in_uniq = false; - for (j=0; j<K; j++) { - if (uniq_y[j] == dataset->y[i]) - in_uniq = true; - } - if (!in_uniq) { - uniq_y = Realloc(uniq_y, long, K+1); - uniq_y[K++] = value; - } - max_y = maximum(max_y, value); - min_y = minimum(min_y, value); + K = maximum(K, dataset->y[i]); } } fclose(fid); - // Correct labels: must be in [1, K] - if (min_y < 1 || max_y > K) { - // LCOV_EXCL_START - err("[GenSVM Error]: Class labels should start from 1 and " - "have no gaps. Please reformat your data.\n"); - exit(EXIT_FAILURE); - // LCOV_EXCL_STOP - } - if (nr < n * m) { // LCOV_EXCL_START err("[GenSVM Error]: not enough data found in %s\n", @@ -165,8 +131,6 @@ void gensvm_read_data(struct GenData *dataset, char *data_file) dataset->RAW = NULL; dataset->Z = NULL; } - - free(uniq_y); } diff --git a/tests/src/test_gensvm_checks.c b/tests/src/test_gensvm_checks.c new file mode 100644 index 0000000..9b163be --- /dev/null +++ b/tests/src/test_gensvm_checks.c @@ -0,0 +1,146 @@ +/** + * @file test_gensvm_checks.c + * @author G.J.J. van den Burg + * @date 2016-12-07 + * @brief Unit tests for gensvm_checks.c functions + * + * @copyright + Copyright 2016, G.J.J. van den Burg. + + This file is part of GenSVM. + + GenSVM is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + GenSVM is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with GenSVM. If not, see <http://www.gnu.org/licenses/>. + + */ + +#include "minunit.h" +#include "gensvm_checks.h" + +char *test_check_outcome_contiguous_correct() +{ + struct GenData *data = gensvm_init_data(); + data->n = 10; + data->y = Calloc(long, data->n); + data->y[0] = 1; + data->y[1] = 2; + data->y[2] = 3; + data->y[3] = 4; + data->y[4] = 1; + data->y[5] = 1; + data->y[6] = 2; + data->y[7] = 2; + data->y[8] = 4; + data->y[9] = 3; + + // start test code // + mu_assert(gensvm_check_outcome_contiguous(data) == true, + "Incorrect check outcome for correct"); + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_check_outcome_contiguous_gap() +{ + struct GenData *data = gensvm_init_data(); + data->n = 10; + data->y = Calloc(long, data->n); + data->y[0] = 1; + data->y[1] = 2; + data->y[2] = 4; + data->y[3] = 4; + data->y[4] = 1; + data->y[5] = 1; + data->y[6] = 2; + data->y[7] = 2; + data->y[8] = 4; + data->y[9] = 4; + + // start test code // + mu_assert(gensvm_check_outcome_contiguous(data) == false, + "Incorrect check outcome for gap"); + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_check_outcome_contiguous_gaps() +{ + struct GenData *data = gensvm_init_data(); + data->n = 10; + data->y = Calloc(long, data->n); + data->y[0] = 1; + data->y[1] = 6; + data->y[2] = 4; + data->y[3] = 4; + data->y[4] = 1; + data->y[5] = 1; + data->y[6] = 6; + data->y[7] = 6; + data->y[8] = 4; + data->y[9] = 4; + + // start test code // + mu_assert(gensvm_check_outcome_contiguous(data) == false, + "Incorrect check outcome for gaps"); + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *test_check_outcome_contiguous_shift() +{ + struct GenData *data = gensvm_init_data(); + data->n = 10; + data->y = Calloc(long, data->n); + data->y[0] = 2; + data->y[1] = 3; + data->y[2] = 4; + data->y[3] = 5; + data->y[4] = 2; + data->y[5] = 3; + data->y[6] = 3; + data->y[7] = 4; + data->y[8] = 5; + data->y[9] = 5; + + // start test code // + mu_assert(gensvm_check_outcome_contiguous(data) == false, + "Incorrect check outcome for shift"); + // end test code // + + gensvm_free_data(data); + + return NULL; +} + +char *all_tests() +{ + mu_suite_start(); + + mu_run_test(test_check_outcome_contiguous_correct); + mu_run_test(test_check_outcome_contiguous_gap); + mu_run_test(test_check_outcome_contiguous_gaps); + mu_run_test(test_check_outcome_contiguous_shift); + + return NULL; +} + +RUN_TESTS(all_tests); diff --git a/tests/src/test_gensvm_train.c b/tests/src/test_gensvm_train.c index 4e3c0b4..4994eb6 100644 --- a/tests/src/test_gensvm_train.c +++ b/tests/src/test_gensvm_train.c @@ -1,8 +1,8 @@ /** - * @file test_gensvm_optimize.c + * @file test_gensvm_train.c * @author G.J.J. van den Burg - * @date 2016-09-01 - * @brief Unit tests for gensvm_optimize.c functions + * @date 2016-12-06 + * @brief Unit tests for gensvm_train.c functions * * @copyright Copyright 2016, G.J.J. van den Burg. |
