aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2016-12-07 16:39:55 +0100
committerGertjan van den Burg <burg@ese.eur.nl>2016-12-07 16:39:55 +0100
commit918d463103215207b9d9975bb6c0ea75754da6f9 (patch)
treee3e466a3a281e223700d1ce2b06c3e5e0f99e935
parentthrow warning when using sparse matrices with kernels (diff)
downloadgensvm-918d463103215207b9d9975bb6c0ea75754da6f9.tar.gz
gensvm-918d463103215207b9d9975bb6c0ea75754da6f9.zip
moved check for class labels to seperate module
-rw-r--r--include/gensvm_checks.h36
-rw-r--r--include/gensvm_globals.h2
-rw-r--r--src/GenSVMgrid.c9
-rw-r--r--src/GenSVMtraintest.c10
-rw-r--r--src/gensvm_checks.c77
-rw-r--r--src/gensvm_io.c44
-rw-r--r--tests/src/test_gensvm_checks.c146
-rw-r--r--tests/src/test_gensvm_train.c6
8 files changed, 286 insertions, 44 deletions
diff --git a/include/gensvm_checks.h b/include/gensvm_checks.h
new file mode 100644
index 0000000..08bad02
--- /dev/null
+++ b/include/gensvm_checks.h
@@ -0,0 +1,36 @@
+/**
+ * @file gensvm_checks.h
+ * @author G.J.J. van den Burg
+ * @date 2016-12-07
+ * @brief Header file for gensvm_checks.c
+ *
+ * @copyright
+ Copyright 2016, G.J.J. van den Burg.
+
+ This file is part of GenSVM.
+
+ GenSVM is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ GenSVM is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
+
+ */
+
+#ifndef GENSVM_CHECKS_H
+#define GENSVM_CHECKS_H
+
+// includes
+#include "gensvm_base.h"
+
+// function declarations
+bool gensvm_check_outcome_contiguous(struct GenData *data);
+
+#endif
diff --git a/include/gensvm_globals.h b/include/gensvm_globals.h
index 8ca3189..1151069 100644
--- a/include/gensvm_globals.h
+++ b/include/gensvm_globals.h
@@ -38,6 +38,7 @@
#include "gensvm_memory.h"
+// all system libraries are included here
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
@@ -46,6 +47,7 @@
#include <math.h>
#include <time.h>
#include <cblas.h>
+#include <limits.h>
// ########################### Type definitions ########################### //
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c
index 1fd7463..3799de3 100644
--- a/src/GenSVMgrid.c
+++ b/src/GenSVMgrid.c
@@ -37,6 +37,7 @@
*/
+#include "gensvm_checks.h"
#include "gensvm_cmdarg.h"
#include "gensvm_io.h"
#include "gensvm_gridsearch.h"
@@ -114,6 +115,14 @@ int main(int argc, char **argv)
note("Reading data from %s\n", grid->train_data_file);
gensvm_read_data(train_data, grid->train_data_file);
+ // check labels of training data
+ gensvm_check_outcome_contiguous(train_data);
+ if (!gensvm_check_outcome_contiguous(train_data)) {
+ err("[GenSVM Error]: Class labels should start from 1 and "
+ "have no gaps. Please reformat your data.\n");
+ exit(EXIT_FAILURE);
+ }
+
// check if we are sparse and want nonlinearity
if (train_data->Z == NULL && grid->kerneltype != K_LINEAR) {
err("[GenSVM Warning]: Sparse matrices with nonlinear kernels "
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c
index ecf455b..8bebf3a 100644
--- a/src/GenSVMtraintest.c
+++ b/src/GenSVMtraintest.c
@@ -28,6 +28,7 @@
*/
+#include "gensvm_checks.h"
#include "gensvm_cmdarg.h"
#include "gensvm_io.h"
#include "gensvm_train.h"
@@ -130,8 +131,15 @@ int main(int argc, char **argv)
&training_inputfile, &testing_inputfile,
&model_outputfile, &prediction_outputfile);
- // read data from files
+ // read data from file and check labels
gensvm_read_data(traindata, training_inputfile);
+ if (!gensvm_check_outcome_contiguous(traindata)) {
+ err("[GenSVM Error]: Class labels should start from 1 and "
+ "have no gaps. Please reformat your data.\n");
+ exit(EXIT_FAILURE);
+ }
+
+ // save data filename to model
model->data_file = Calloc(char, GENSVM_MAX_LINE_LENGTH);
strcpy(model->data_file, training_inputfile);
diff --git a/src/gensvm_checks.c b/src/gensvm_checks.c
new file mode 100644
index 0000000..0f7c499
--- /dev/null
+++ b/src/gensvm_checks.c
@@ -0,0 +1,77 @@
+/**
+ * @file gensvm_checks.c
+ * @author G.J.J. van den Burg
+ * @date 2016-12-07
+ * @brief Sanity checks used to ensure inputs are as expected
+ *
+ * @copyright
+ Copyright 2016, G.J.J. van den Burg.
+
+ This file is part of GenSVM.
+
+ GenSVM is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ GenSVM is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
+
+ */
+
+#include "gensvm_checks.h"
+
+/**
+ * @brief Check if the labels are contiguous on [1 .. K]
+ *
+ * @details
+ * The GenSVM library currently requires that the labels that are supplied in
+ * a dataset are contigous on the interval [1 .. K] and have no gaps. This is
+ * required because the dimensionality of the problem is directly related to
+ * the maximum class label K. This function checks if the labels are indeed in
+ * the desired range.
+ *
+ * @param[in] data a GenData struct with the current data
+ *
+ * @return whether the labels are contiguous or not
+ */
+bool gensvm_check_outcome_contiguous(struct GenData *data)
+{
+ bool in_uniq, is_contiguous = true;
+ long i, j, K = 1;
+ long max_y = -1,
+ min_y = LONG_MAX;
+ long *uniq_y = Calloc(long, K);
+ uniq_y[0] = data->y[0];
+
+ for (i=1; i<data->n; i++) {
+ in_uniq = false;
+ for (j=0; j<K; j++) {
+ if (uniq_y[j] == data->y[i]) {
+ in_uniq = true;
+ break;
+ }
+ }
+
+ if (!in_uniq) {
+ uniq_y = Realloc(uniq_y, long, K+1);
+ uniq_y[K++] = data->y[i];
+ }
+
+ max_y = maximum(max_y, data->y[i]);
+ min_y = minimum(min_y, data->y[i]);
+ }
+
+ if (min_y < 1 || max_y > K) {
+ is_contiguous = false;
+ }
+
+ free(uniq_y);
+
+ return is_contiguous;
+}
diff --git a/src/gensvm_io.c b/src/gensvm_io.c
index 78838b1..0d14144 100644
--- a/src/gensvm_io.c
+++ b/src/gensvm_io.c
@@ -29,7 +29,6 @@
*/
-#include <limits.h>
#include "gensvm_io.h"
/**
@@ -39,12 +38,8 @@
* Read the data from the data_file. The data matrix X is augmented
* with a column of ones, to get the matrix Z. The data is expected
* to follow a specific format, which is specified in the @ref spec_data_file.
- * The class labels are checked to make sure they correspond to the interval
- * [1 .. K], where K is the total number of classes, without any gaps.
- *
- * @todo
- * Make sure that this function allows datasets without class labels for
- * testing.
+ * The class labels are assumed to be in the interval [1 .. K], which can be
+ * checked using the function gensvm_check_outcome_contiguous().
*
* @param[in,out] dataset initialized GenData struct
* @param[in] data_file filename of the data file.
@@ -52,15 +47,10 @@
void gensvm_read_data(struct GenData *dataset, char *data_file)
{
FILE *fid = NULL;
- bool in_uniq;
long i, j, n, m,
nr = 0,
- K = 0,
- max_y = -1,
- min_y = LONG_MAX;
+ K = 0;
double value;
- long *uniq_y = NULL;
-
char buf[GENSVM_MAX_LINE_LENGTH];
if ((fid = fopen(data_file, "r")) == NULL) {
@@ -96,8 +86,6 @@ void gensvm_read_data(struct GenData *dataset, char *data_file)
dataset->y = Malloc(long, n);
dataset->y[0] = value;
K = 1;
- uniq_y = Calloc(long, K);
- uniq_y[0] = value;
} else {
free(dataset->y);
dataset->y = NULL;
@@ -112,33 +100,11 @@ void gensvm_read_data(struct GenData *dataset, char *data_file)
if (dataset->y != NULL) {
nr += fscanf(fid, "%lf", &value);
dataset->y[i] = (long) value;
-
- // this is to keep track of the unique values of y, so
- // we can warn when they're not encoded correctly
- in_uniq = false;
- for (j=0; j<K; j++) {
- if (uniq_y[j] == dataset->y[i])
- in_uniq = true;
- }
- if (!in_uniq) {
- uniq_y = Realloc(uniq_y, long, K+1);
- uniq_y[K++] = value;
- }
- max_y = maximum(max_y, value);
- min_y = minimum(min_y, value);
+ K = maximum(K, dataset->y[i]);
}
}
fclose(fid);
- // Correct labels: must be in [1, K]
- if (min_y < 1 || max_y > K) {
- // LCOV_EXCL_START
- err("[GenSVM Error]: Class labels should start from 1 and "
- "have no gaps. Please reformat your data.\n");
- exit(EXIT_FAILURE);
- // LCOV_EXCL_STOP
- }
-
if (nr < n * m) {
// LCOV_EXCL_START
err("[GenSVM Error]: not enough data found in %s\n",
@@ -165,8 +131,6 @@ void gensvm_read_data(struct GenData *dataset, char *data_file)
dataset->RAW = NULL;
dataset->Z = NULL;
}
-
- free(uniq_y);
}
diff --git a/tests/src/test_gensvm_checks.c b/tests/src/test_gensvm_checks.c
new file mode 100644
index 0000000..9b163be
--- /dev/null
+++ b/tests/src/test_gensvm_checks.c
@@ -0,0 +1,146 @@
+/**
+ * @file test_gensvm_checks.c
+ * @author G.J.J. van den Burg
+ * @date 2016-12-07
+ * @brief Unit tests for gensvm_checks.c functions
+ *
+ * @copyright
+ Copyright 2016, G.J.J. van den Burg.
+
+ This file is part of GenSVM.
+
+ GenSVM is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ GenSVM is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with GenSVM. If not, see <http://www.gnu.org/licenses/>.
+
+ */
+
+#include "minunit.h"
+#include "gensvm_checks.h"
+
+char *test_check_outcome_contiguous_correct()
+{
+ struct GenData *data = gensvm_init_data();
+ data->n = 10;
+ data->y = Calloc(long, data->n);
+ data->y[0] = 1;
+ data->y[1] = 2;
+ data->y[2] = 3;
+ data->y[3] = 4;
+ data->y[4] = 1;
+ data->y[5] = 1;
+ data->y[6] = 2;
+ data->y[7] = 2;
+ data->y[8] = 4;
+ data->y[9] = 3;
+
+ // start test code //
+ mu_assert(gensvm_check_outcome_contiguous(data) == true,
+ "Incorrect check outcome for correct");
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_check_outcome_contiguous_gap()
+{
+ struct GenData *data = gensvm_init_data();
+ data->n = 10;
+ data->y = Calloc(long, data->n);
+ data->y[0] = 1;
+ data->y[1] = 2;
+ data->y[2] = 4;
+ data->y[3] = 4;
+ data->y[4] = 1;
+ data->y[5] = 1;
+ data->y[6] = 2;
+ data->y[7] = 2;
+ data->y[8] = 4;
+ data->y[9] = 4;
+
+ // start test code //
+ mu_assert(gensvm_check_outcome_contiguous(data) == false,
+ "Incorrect check outcome for gap");
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_check_outcome_contiguous_gaps()
+{
+ struct GenData *data = gensvm_init_data();
+ data->n = 10;
+ data->y = Calloc(long, data->n);
+ data->y[0] = 1;
+ data->y[1] = 6;
+ data->y[2] = 4;
+ data->y[3] = 4;
+ data->y[4] = 1;
+ data->y[5] = 1;
+ data->y[6] = 6;
+ data->y[7] = 6;
+ data->y[8] = 4;
+ data->y[9] = 4;
+
+ // start test code //
+ mu_assert(gensvm_check_outcome_contiguous(data) == false,
+ "Incorrect check outcome for gaps");
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_check_outcome_contiguous_shift()
+{
+ struct GenData *data = gensvm_init_data();
+ data->n = 10;
+ data->y = Calloc(long, data->n);
+ data->y[0] = 2;
+ data->y[1] = 3;
+ data->y[2] = 4;
+ data->y[3] = 5;
+ data->y[4] = 2;
+ data->y[5] = 3;
+ data->y[6] = 3;
+ data->y[7] = 4;
+ data->y[8] = 5;
+ data->y[9] = 5;
+
+ // start test code //
+ mu_assert(gensvm_check_outcome_contiguous(data) == false,
+ "Incorrect check outcome for shift");
+ // end test code //
+
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+
+ mu_run_test(test_check_outcome_contiguous_correct);
+ mu_run_test(test_check_outcome_contiguous_gap);
+ mu_run_test(test_check_outcome_contiguous_gaps);
+ mu_run_test(test_check_outcome_contiguous_shift);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_train.c b/tests/src/test_gensvm_train.c
index 4e3c0b4..4994eb6 100644
--- a/tests/src/test_gensvm_train.c
+++ b/tests/src/test_gensvm_train.c
@@ -1,8 +1,8 @@
/**
- * @file test_gensvm_optimize.c
+ * @file test_gensvm_train.c
* @author G.J.J. van den Burg
- * @date 2016-09-01
- * @brief Unit tests for gensvm_optimize.c functions
+ * @date 2016-12-06
+ * @brief Unit tests for gensvm_train.c functions
*
* @copyright
Copyright 2016, G.J.J. van den Burg.