aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2016-09-20 16:46:41 +0200
committerGertjan van den Burg <burg@ese.eur.nl>2016-09-20 16:46:41 +0200
commit28497ecc371fd4c731892de91b07f635f1020452 (patch)
tree6c1876f3e64e7d0d310f29d4f20fbf69d96f3ef4 /tests
parentMinor improvements and fixes (diff)
downloadgensvm-28497ecc371fd4c731892de91b07f635f1020452.tar.gz
gensvm-28497ecc371fd4c731892de91b07f635f1020452.zip
Unit tests and corresponding data
Diffstat (limited to 'tests')
-rw-r--r--tests/data/test_debug_print.txt4
-rw-r--r--tests/data/test_file_read_data.txt7
-rw-r--r--tests/data/test_file_read_data_no_label.txt7
-rw-r--r--tests/data/test_read_model.txt20
-rw-r--r--tests/data/test_write_model.txt20
-rw-r--r--tests/data/test_write_predictions.txt7
-rw-r--r--tests/src/test_gensvm_cv_util.c257
-rw-r--r--tests/src/test_gensvm_debug.c60
-rw-r--r--tests/src/test_gensvm_init.c215
-rw-r--r--tests/src/test_gensvm_io.c431
-rw-r--r--tests/src/test_gensvm_optimize.c450
-rw-r--r--tests/src/test_gensvm_pred.c165
-rw-r--r--tests/src/test_gensvm_sv.c62
13 files changed, 1705 insertions, 0 deletions
diff --git a/tests/data/test_debug_print.txt b/tests/data/test_debug_print.txt
new file mode 100644
index 0000000..9ef2b13
--- /dev/null
+++ b/tests/data/test_debug_print.txt
@@ -0,0 +1,4 @@
+-0.241053 -0.599809
++0.893318 -0.344058
++0.933948 -0.474352
+
diff --git a/tests/data/test_file_read_data.txt b/tests/data/test_file_read_data.txt
new file mode 100644
index 0000000..49ec392
--- /dev/null
+++ b/tests/data/test_file_read_data.txt
@@ -0,0 +1,7 @@
+5
+3
+0.7065937536993949 0.7016517970438980 0.1548611397288129 2
+0.4604987687863951 0.6374142980176117 0.0370930278245423 1
+0.3798777132278375 0.5745070018747664 0.2570906697837264 3
+0.2789376050039792 0.4853242744610165 0.1894010436762711 4
+0.7630904372339489 0.1341546320318005 0.6827430912944857 3
diff --git a/tests/data/test_file_read_data_no_label.txt b/tests/data/test_file_read_data_no_label.txt
new file mode 100644
index 0000000..7d11fed
--- /dev/null
+++ b/tests/data/test_file_read_data_no_label.txt
@@ -0,0 +1,7 @@
+5
+3
+0.7065937536993949 0.7016517970438980 0.1548611397288129
+0.4604987687863951 0.6374142980176117 0.0370930278245423
+0.3798777132278375 0.5745070018747664 0.2570906697837264
+0.2789376050039792 0.4853242744610165 0.1894010436762711
+0.7630904372339489 0.1341546320318005 0.6827430912944857
diff --git a/tests/data/test_read_model.txt b/tests/data/test_read_model.txt
new file mode 100644
index 0000000..a33f43b
--- /dev/null
+++ b/tests/data/test_read_model.txt
@@ -0,0 +1,20 @@
+Output file for GenSVM (version 0.1)
+Generated on: Fri Sep 16 18:24:18 2016 (UTC +02:00)
+
+Model:
+p = 0.21398
+lambda = 0.902131
+kappa = 1.0213
+epsilon = 1e-10
+weight_idx = 2
+
+Data:
+filename = ./data/test_file_read_data.txt
+n = 10
+m = 2
+K = 3
+
+Output:
++0.1234 -0.4321
+-0.5678 +0.9876
++0.9012 -0.5555
diff --git a/tests/data/test_write_model.txt b/tests/data/test_write_model.txt
new file mode 100644
index 0000000..15e7372
--- /dev/null
+++ b/tests/data/test_write_model.txt
@@ -0,0 +1,20 @@
+Output file for GenSVM (version 0.1)
+Generated on: Tue Sep 20 16:30:03 2016 (UTC +02:00)
+
+Model:
+p = 0.9032800000000000
+lambda = 0.0130000000000000
+kappa = 1.1832000000000000
+epsilon = 1e-08
+weight_idx = 1
+
+Data:
+filename = ./data/test_file_read_data.txt
+n = 10
+m = 2
+K = 3
+
+Output:
++0.4989893785603748 +0.0599082796573645
++0.7918204761759593 +0.6456613497110559
++0.9711956316284261 +0.5010714686310176
diff --git a/tests/data/test_write_predictions.txt b/tests/data/test_write_predictions.txt
new file mode 100644
index 0000000..a29c197
--- /dev/null
+++ b/tests/data/test_write_predictions.txt
@@ -0,0 +1,7 @@
+5
+3
+0.7065937536993949 0.7016517970438980 0.1548611397288129 3
+0.4604987687863951 0.6374142980176117 0.0370930278245423 2
+0.3798777132278375 0.5745070018747664 0.2570906697837264 1
+0.2789376050039792 0.4853242744610165 0.1894010436762711 2
+0.7630904372339489 0.1341546320318005 0.6827430912944857 1
diff --git a/tests/src/test_gensvm_cv_util.c b/tests/src/test_gensvm_cv_util.c
new file mode 100644
index 0000000..5cbf174
--- /dev/null
+++ b/tests/src/test_gensvm_cv_util.c
@@ -0,0 +1,257 @@
+/**
+ * @file test_gensvm_cv_util.c
+ * @author Gertjan van den Burg
+ * @date May, 2016
+ * @brief Unit tests for gensvm_cv_util.c functions
+ */
+
+#include "minunit.h"
+#include "gensvm_cv_util.h"
+
+char *test_make_cv_split_1()
+{
+ srand(0);
+ int i, j;
+ long N = 10;
+ long folds = 4;
+ long *cv_idx = Calloc(long, N);
+
+ // start test code //
+ gensvm_make_cv_split(N, folds, cv_idx);
+ // check if the values are between [0, folds-1]
+ for (i=0; i<N; i++)
+ mu_assert(0 <= cv_idx[i] && cv_idx[i] < folds,
+ "CV range incorrect.");
+
+ // check there are N % folds big folds of size floor(N/folds) + 1
+ // and the remaining are of size floor(N/folds)
+ int sum;
+ int is_big = 0,
+ is_small = 0;
+ for (i=0; i<folds; i++) {
+ sum = 0;
+ for (j=0; j<N; j++) {
+ if (cv_idx[j] == i) sum += 1;
+ }
+ if (sum == floor(N/folds) + 1)
+ is_big++;
+ else
+ is_small++;
+ }
+ mu_assert(is_big == N % folds, "Incorrect number of big folds");
+ mu_assert(is_small == folds - N % folds,
+ "Incorrect number of small folds");
+
+ // end test code //
+
+ free(cv_idx);
+
+ return NULL;
+}
+
+char *test_make_cv_split_2()
+{
+ srand(0);
+ int i, j;
+ long N = 101;
+ long folds = 7;
+ long *cv_idx = Calloc(long, N);
+
+ // start test code //
+ gensvm_make_cv_split(N, folds, cv_idx);
+ // check if the values are between [0, folds-1]
+ for (i=0; i<N; i++)
+ mu_assert(0 <= cv_idx[i] && cv_idx[i] < folds,
+ "CV range incorrect.");
+
+ // check there are N % folds big folds of size floor(N/folds) + 1
+ // and the remaining are of size floor(N/folds)
+ int sum;
+ int is_big = 0,
+ is_small = 0;
+ for (i=0; i<folds; i++) {
+ sum = 0;
+ for (j=0; j<N; j++) {
+ if (cv_idx[j] == i) sum += 1;
+ }
+ if (sum == floor(N/folds) + 1)
+ is_big++;
+ else
+ is_small++;
+ }
+ mu_assert(is_big == N % folds, "Incorrect number of big folds");
+ mu_assert(is_small == folds - N % folds,
+ "Incorrect number of small folds");
+
+ // end test code //
+
+ free(cv_idx);
+
+ return NULL;
+}
+
+
+char *test_get_tt_split()
+{
+ struct GenData *full = gensvm_init_data();
+ full->K = 3;
+ full->n = 10;
+ full->m = 2;
+ full->r = 2;
+
+ full->y = Calloc(long, full->n);
+ full->y[0] = 1;
+ full->y[1] = 2;
+ full->y[2] = 3;
+ full->y[3] = 1;
+ full->y[4] = 2;
+ full->y[5] = 3;
+ full->y[6] = 1;
+ full->y[7] = 2;
+ full->y[8] = 3;
+ full->y[9] = 1;
+
+ full->RAW = Calloc(double, full->n * (full->m+1));
+ matrix_set(full->RAW, full->m+1, 0, 1, 1.0);
+ matrix_set(full->RAW, full->m+1, 0, 2, 1.0);
+ matrix_set(full->RAW, full->m+1, 1, 1, 2.0);
+ matrix_set(full->RAW, full->m+1, 1, 2, 2.0);
+ matrix_set(full->RAW, full->m+1, 2, 1, 3.0);
+ matrix_set(full->RAW, full->m+1, 2, 2, 3.0);
+ matrix_set(full->RAW, full->m+1, 3, 1, 4.0);
+ matrix_set(full->RAW, full->m+1, 3, 2, 4.0);
+ matrix_set(full->RAW, full->m+1, 4, 1, 5.0);
+ matrix_set(full->RAW, full->m+1, 4, 2, 5.0);
+ matrix_set(full->RAW, full->m+1, 5, 1, 6.0);
+ matrix_set(full->RAW, full->m+1, 5, 2, 6.0);
+ matrix_set(full->RAW, full->m+1, 6, 1, 7.0);
+ matrix_set(full->RAW, full->m+1, 6, 2, 7.0);
+ matrix_set(full->RAW, full->m+1, 7, 1, 8.0);
+ matrix_set(full->RAW, full->m+1, 7, 2, 8.0);
+ matrix_set(full->RAW, full->m+1, 8, 1, 9.0);
+ matrix_set(full->RAW, full->m+1, 8, 2, 9.0);
+ matrix_set(full->RAW, full->m+1, 9, 1, 10.0);
+ matrix_set(full->RAW, full->m+1, 9, 2, 10.0);
+
+ long *cv_idx = Calloc(long, full->n);
+ cv_idx[0] = 1;
+ cv_idx[1] = 0;
+ cv_idx[2] = 1;
+ cv_idx[3] = 0;
+ cv_idx[4] = 1;
+ cv_idx[5] = 2;
+ cv_idx[6] = 3;
+ cv_idx[7] = 2;
+ cv_idx[8] = 3;
+ cv_idx[9] = 2;
+
+ struct GenData *train = gensvm_init_data();
+ struct GenData *test = gensvm_init_data();
+
+ // start test code //
+ gensvm_get_tt_split(full, train, test, cv_idx, 0);
+
+ mu_assert(train->n == 8, "train_n incorrect.");
+ mu_assert(test->n == 2, "test_n incorrect.");
+
+ mu_assert(train->m == 2, "train_m incorrect.");
+ mu_assert(test->m == 2, "test_m incorrect.");
+
+ mu_assert(train->K == 3, "train_K incorrect.");
+ mu_assert(test->K == 3, "test_K incorrect.");
+
+ mu_assert(train->y[0] == 1, "train y incorrect.");
+ mu_assert(train->y[1] == 3, "train y incorrect.");
+ mu_assert(train->y[2] == 2, "train y incorrect.");
+ mu_assert(train->y[3] == 3, "train y incorrect.");
+ mu_assert(train->y[4] == 1, "train y incorrect.");
+ mu_assert(train->y[5] == 2, "train y incorrect.");
+ mu_assert(train->y[6] == 3, "train y incorrect.");
+ mu_assert(train->y[7] == 1, "train y incorrect.");
+
+ mu_assert(test->y[0] == 2, "test y incorrect.");
+ mu_assert(test->y[1] == 1, "test y incorrect.");
+
+ mu_assert(matrix_get(train->RAW, train->m+1, 0, 0) == 0.0,
+ "train RAW 0, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 0, 1) == 1.0,
+ "train RAW 0, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 0, 2) == 1.0,
+ "train RAW 0, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 1, 0) == 0.0,
+ "train RAW 1, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 1, 1) == 3.0,
+ "train RAW 1, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 1, 2) == 3.0,
+ "train RAW 1, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 2, 0) == 0.0,
+ "train RAW 2, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 2, 1) == 5.0,
+ "train RAW 2, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 2, 2) == 5.0,
+ "train RAW 2, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 3, 0) == 0.0,
+ "train RAW 3, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 3, 1) == 6.0,
+ "train RAW 3, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 3, 2) == 6.0,
+ "train RAW 3, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 4, 0) == 0.0,
+ "train RAW 4, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 4, 1) == 7.0,
+ "train RAW 4, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 4, 2) == 7.0,
+ "train RAW 4, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 5, 0) == 0.0,
+ "train RAW 5, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 5, 1) == 8.0,
+ "train RAW 5, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 5, 2) == 8.0,
+ "train RAW 5, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 6, 0) == 0.0,
+ "train RAW 6, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 6, 1) == 9.0,
+ "train RAW 6, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 6, 2) == 9.0,
+ "train RAW 6, 2 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 7, 0) == 0.0,
+ "train RAW 7, 0 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 7, 1) == 10.0,
+ "train RAW 7, 1 incorrect.");
+ mu_assert(matrix_get(train->RAW, train->m+1, 7, 2) == 10.0,
+ "train RAW 7, 2 incorrect.");
+
+ mu_assert(matrix_get(test->RAW, train->m+1, 0, 0) == 0.0,
+ "test RAW 0, 0 incorrect.");
+ mu_assert(matrix_get(test->RAW, train->m+1, 0, 1) == 2.0,
+ "test RAW 0, 1 incorrect.");
+ mu_assert(matrix_get(test->RAW, train->m+1, 0, 2) == 2.0,
+ "test RAW 0, 2 incorrect.");
+ mu_assert(matrix_get(test->RAW, train->m+1, 1, 0) == 0.0,
+ "test RAW 1, 0 incorrect.");
+ mu_assert(matrix_get(test->RAW, train->m+1, 1, 1) == 4.0,
+ "test RAW 1, 1 incorrect.");
+ mu_assert(matrix_get(test->RAW, train->m+1, 1, 2) == 4.0,
+ "test RAW 1, 2 incorrect.");
+
+ // end test code //
+ gensvm_free_data(full);
+ gensvm_free_data(train);
+ gensvm_free_data(test);
+ free(cv_idx);
+
+ return NULL;
+}
+
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_make_cv_split_1);
+ mu_run_test(test_make_cv_split_2);
+ mu_run_test(test_get_tt_split);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_debug.c b/tests/src/test_gensvm_debug.c
new file mode 100644
index 0000000..ea2fbba
--- /dev/null
+++ b/tests/src/test_gensvm_debug.c
@@ -0,0 +1,60 @@
+/**
+ * @file test_gensvm_io.c
+ * @author Gertjan van den Burg
+ * @date September, 2016
+ * @brief Unit tests for gensvm_io.c functions
+ */
+
+#include "minunit.h"
+#include "gensvm_debug.h"
+
+extern FILE *GENSVM_OUTPUT_FILE;
+
+char *test_print_matrix()
+{
+ FILE *fid;
+ GENSVM_OUTPUT_FILE = fopen("./data/test_debug_print.txt", "w");
+
+ double *mat = Calloc(double, 3*2);
+ matrix_set(mat, 2, 0, 0, -0.241053050258449);
+ matrix_set(mat, 2, 0, 1, -0.599809408260836);
+ matrix_set(mat, 2, 1, 0, 0.893318163305108);
+ matrix_set(mat, 2, 1, 1, -0.344057630469285);
+ matrix_set(mat, 2, 2, 0, 0.933948479216127);
+ matrix_set(mat, 2, 2, 1, -0.474352026604967);
+
+ // start test code //
+ gensvm_print_matrix(mat, 3, 2);
+ fclose(GENSVM_OUTPUT_FILE);
+
+ char buffer[MAX_LINE_LENGTH];
+ fid = fopen("./data/test_debug_print.txt", "r");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "-0.241053 -0.599809\n") == 0,
+ "Line doesn't contain expected content (0).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "+0.893318 -0.344058\n") == 0,
+ "Line doesn't contain expected content (1).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "+0.933948 -0.474352\n") == 0,
+ "Line doesn't contain expected content (2).\n");
+
+ fclose(fid);
+ // end test code //
+ free(mat);
+
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_print_matrix);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_init.c b/tests/src/test_gensvm_init.c
new file mode 100644
index 0000000..8cacd9b
--- /dev/null
+++ b/tests/src/test_gensvm_init.c
@@ -0,0 +1,215 @@
+/**
+ * @file test_gensvm_init.c
+ * @author Gertjan van den Burg
+ * @date August, 2016
+ * @brief Unit tests for gensvm_init.c functions
+ */
+
+#include "minunit.h"
+#include "gensvm_init.h"
+
+char *test_init_null()
+{
+ int i;
+ long n = 5,
+ m = 2,
+ K = 3;
+ double value;
+ struct GenModel *model = gensvm_init_model();
+ struct GenData *data = gensvm_init_data();
+
+ // start test code
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ gensvm_allocate_model(model);
+
+ data->n = n;
+ data->m = m;
+ data->K = K;
+ data->RAW = Calloc(double, n*(m+1));
+ for (i=0; i<n; i++) {
+ matrix_set(data->RAW, m+1, i, 0, 1.0);
+ matrix_set(data->RAW, m+1, i, 1, i);
+ matrix_set(data->RAW, m+1, i, 2, -i);
+ }
+ data->Z = data->RAW;
+
+ gensvm_init_V(NULL, model, data);
+
+ // first row all ones
+ value = matrix_get(model->V, K-1, 0, 0);
+ mu_assert(value == 1.0, "Incorrect value at 0, 0");
+ value = matrix_get(model->V, K-1, 0, 1);
+ mu_assert(value == 1.0, "Incorrect value at 0, 1");
+
+ // second row between -1 and 0.25
+ value = matrix_get(model->V, K-1, 1, 0);
+ mu_assert(value >= -1.0 && value <= 0.25, "Incorrect value at 1, 0");
+ value = matrix_get(model->V, K-1, 1, 1);
+ mu_assert(value >= -1.0 && value <= 0.25, "Incorrect value at 1, 1");
+
+ // third row between -0.25 and 1
+ value = matrix_get(model->V, K-1, 2, 0);
+ mu_assert(value >= -0.25 && value <= 1.0, "Incorrect value at 2, 0");
+ value = matrix_get(model->V, K-1, 2, 1);
+ mu_assert(value >= -0.25 && value <= 1.0, "Incorrect value at 2, 1");
+
+ // end test code
+
+ gensvm_free_model(model);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_init_seed()
+{
+
+ long n = 7,
+ m = 5,
+ K = 3;
+ struct GenModel *model = gensvm_init_model();
+ struct GenModel *seed = gensvm_init_model();
+ struct GenData *data = gensvm_init_data();
+
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ seed->n = n;
+ seed->m = m;
+ seed->K = K;
+ data->n = n;
+ data->m = m;
+ data->K = K;
+ gensvm_allocate_model(model);
+ gensvm_allocate_model(seed);
+
+ // start test code
+ matrix_set(seed->V, seed->K-1, 0, 0, 123.0);
+ matrix_set(seed->V, seed->K-1, 1, 1, 321.0);
+ matrix_set(seed->V, seed->K-1, 2, 0, 111.0);
+ matrix_set(seed->V, seed->K-1, 5, 0, 222.0);
+ matrix_set(seed->V, seed->K-1, 3, 1, 333.0);
+
+ gensvm_init_V(seed, model, data);
+
+ mu_assert(matrix_get(model->V, model->K-1, 0, 0) == 123.0,
+ "Incorrect V value at 0, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 0, 1) == 0.0,
+ "Incorrect V value at 0, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 1, 0) == 0.0,
+ "Incorrect V value at 1, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 1, 1) == 321.0,
+ "Incorrect V value at 1, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 2, 0) == 111.0,
+ "Incorrect V value at 2, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 2, 1) == 0.0,
+ "Incorrect V value at 2, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 3, 0) == 0.0,
+ "Incorrect V value at 3, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 3, 1) == 333.0,
+ "Incorrect V value at 3, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 4, 0) == 0.0,
+ "Incorrect V value at 4, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 4, 1) == 0.0,
+ "Incorrect V value at 4, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 5, 0) == 222.0,
+ "Incorrect V value at 5, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 5, 1) == 0.0,
+ "Incorrect V value at 5, 1");
+ // end test code
+
+ gensvm_free_model(model);
+ gensvm_free_model(seed);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_init_weights_1()
+{
+ struct GenModel *model = gensvm_init_model();
+ struct GenData *data = NULL;
+ model->n = 7;
+ model->m = 5;
+ model->K = 3;
+ model->weight_idx = 1;
+ gensvm_allocate_model(model);
+
+ // start test code
+ int i;
+ gensvm_initialize_weights(data, model);
+ for (i=0; i<model->n; i++) {
+ mu_assert(model->rho[i] == 1.0, "incorrect weight in rho");
+ }
+ // end test code
+
+ gensvm_free_model(model);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_init_weights_2()
+{
+ struct GenModel *model = gensvm_init_model();
+ struct GenData *data = gensvm_init_data();
+ model->n = 8;
+ model->m = 5;
+ model->K = 3;
+ model->weight_idx = 2;
+ gensvm_allocate_model(model);
+
+ data->y = Calloc(long, model->n);
+
+ data->y[0] = 1;
+ data->y[1] = 1;
+ data->y[2] = 1;
+ data->y[3] = 1;
+ data->y[4] = 2;
+ data->y[5] = 2;
+ data->y[6] = 2;
+ data->y[7] = 3;
+
+ // start test code
+ gensvm_initialize_weights(data, model);
+ int i;
+ for (i=0; i<4; i++) {
+ mu_assert(model->rho[i] == 8.0/(4.0 * 3.0),
+ "Incorrect weight for class 1");
+ }
+ for (i=0; i<3; i++) {
+ mu_assert(model->rho[4+i] == 8.0/(3.0 * 3.0),
+ "Incorrect weight for class 2");
+ }
+ mu_assert(model->rho[7] == 8.0/(1.0 * 3.0),
+ "Incorrect weight for class 3");
+
+ // end test code
+
+ gensvm_free_model(model);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_init_weights_wrong()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_init_null);
+ mu_run_test(test_init_seed);
+ mu_run_test(test_init_weights_1);
+ mu_run_test(test_init_weights_2);
+ mu_run_test(test_init_weights_wrong);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_io.c b/tests/src/test_gensvm_io.c
new file mode 100644
index 0000000..c97e2d8
--- /dev/null
+++ b/tests/src/test_gensvm_io.c
@@ -0,0 +1,431 @@
+/**
+ * @file test_gensvm_io.c
+ * @author Gertjan van den Burg
+ * @date September, 2016
+ * @brief Unit tests for gensvm_io.c functions
+ */
+
+#include "minunit.h"
+#include "gensvm_io.h"
+
+char *test_gensvm_read_data()
+{
+ char *filename = "./data/test_file_read_data.txt";
+ struct GenData *data = gensvm_init_data();
+
+ // start test code //
+ gensvm_read_data(data, filename);
+
+ // check if dimensions are correctly read
+ mu_assert(data->n == 5, "Incorrect value for n");
+ mu_assert(data->m == 3, "Incorrect value for m");
+ mu_assert(data->r == 3, "Incorrect value for r");
+ mu_assert(data->K == 4, "Incorrect value for K");
+
+ // check if all data is read correctly.
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949,
+ "Incorrect Z value at 0, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980,
+ "Incorrect Z value at 0, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129,
+ "Incorrect Z value at 0, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951,
+ "Incorrect Z value at 1, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117,
+ "Incorrect Z value at 1, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423,
+ "Incorrect Z value at 1, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375,
+ "Incorrect Z value at 2, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664,
+ "Incorrect Z value at 2, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264,
+ "Incorrect Z value at 2, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792,
+ "Incorrect Z value at 3, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165,
+ "Incorrect Z value at 3, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711,
+ "Incorrect Z value at 3, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489,
+ "Incorrect Z value at 4, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005,
+ "Incorrect Z value at 4, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857,
+ "Incorrect Z value at 4, 3");
+ // check if RAW = Z
+ mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer");
+
+ // check if labels read correctly
+ mu_assert(data->y[0] == 2, "Incorrect label read at 0");
+ mu_assert(data->y[1] == 1, "Incorrect label read at 1");
+ mu_assert(data->y[2] == 3, "Incorrect label read at 2");
+ mu_assert(data->y[3] == 4, "Incorrect label read at 3");
+ mu_assert(data->y[4] == 3, "Incorrect label read at 4");
+
+ // check if the column of ones is added
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1,
+ "Incorrect Z value at 0, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1,
+ "Incorrect Z value at 1, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1,
+ "Incorrect Z value at 2, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1,
+ "Incorrect Z value at 3, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1,
+ "Incorrect Z value at 4, 0");
+
+ // end test code //
+
+ gensvm_free_data(data);
+ return NULL;
+}
+
+char *test_gensvm_read_data_no_label()
+{
+ char *filename = "./data/test_file_read_data_no_label.txt";
+ struct GenData *data = gensvm_init_data();
+
+ // start test code //
+ gensvm_read_data(data, filename);
+
+ // check if dimensions are correctly read
+ mu_assert(data->n == 5, "Incorrect value for n");
+ mu_assert(data->m == 3, "Incorrect value for m");
+ mu_assert(data->r == 3, "Incorrect value for r");
+ mu_assert(data->K == 0, "Incorrect value for K");
+
+ // check if all data is read correctly.
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 1) == 0.7065937536993949,
+ "Incorrect Z value at 0, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 2) == 0.7016517970438980,
+ "Incorrect Z value at 0, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 3) == 0.1548611397288129,
+ "Incorrect Z value at 0, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 1) == 0.4604987687863951,
+ "Incorrect Z value at 1, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 2) == 0.6374142980176117,
+ "Incorrect Z value at 1, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 3) == 0.0370930278245423,
+ "Incorrect Z value at 1, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 1) == 0.3798777132278375,
+ "Incorrect Z value at 2, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 2) == 0.5745070018747664,
+ "Incorrect Z value at 2, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 3) == 0.2570906697837264,
+ "Incorrect Z value at 2, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 1) == 0.2789376050039792,
+ "Incorrect Z value at 3, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 2) == 0.4853242744610165,
+ "Incorrect Z value at 3, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 3) == 0.1894010436762711,
+ "Incorrect Z value at 3, 3");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 1) == 0.7630904372339489,
+ "Incorrect Z value at 4, 1");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 2) == 0.1341546320318005,
+ "Incorrect Z value at 4, 2");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 3) == 0.6827430912944857,
+ "Incorrect Z value at 4, 3");
+ // check if RAW = Z
+ mu_assert(data->Z == data->RAW, "Z pointer doesn't equal RAW pointer");
+
+ // check if labels read correctly
+ mu_assert(data->y == NULL, "Outcome pointer is not NULL");
+
+ // check if the column of ones is added
+ mu_assert(matrix_get(data->Z, data->m+1, 0, 0) == 1,
+ "Incorrect Z value at 0, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 1, 0) == 1,
+ "Incorrect Z value at 1, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 2, 0) == 1,
+ "Incorrect Z value at 2, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 3, 0) == 1,
+ "Incorrect Z value at 3, 0");
+ mu_assert(matrix_get(data->Z, data->m+1, 4, 0) == 1,
+ "Incorrect Z value at 4, 0");
+
+ // end test code //
+
+ gensvm_free_data(data);
+ return NULL;
+}
+
+char *test_gensvm_read_model()
+{
+ struct GenModel *model = gensvm_init_model();
+ char *filename = "./data/test_read_model.txt";
+
+ // start test code //
+ gensvm_read_model(model, filename);
+
+ mu_assert(model->p == 0.21398, "Incorrect read for model->p");
+ mu_assert(model->lambda == 0.902131, "Incorrect read for "
+ "model->lambda");
+ mu_assert(model->kappa == 1.0213, "Incorrect read for model->kappa");
+ mu_assert(model->epsilon == 1e-10, "Incorrect read for "
+ "model->epsilon");
+ mu_assert(model->weight_idx == 2, "Incorrect read for "
+ "model->weight_idx");
+
+ mu_assert(strcmp(model->data_file, "./data/test_file_read_data.txt")
+ == 0, "Incorrect read for model->data_file");
+ mu_assert(model->n == 10, "Incorrect read for model->n");
+ mu_assert(model->m == 2, "Incorrect read for model->m");
+ mu_assert(model->K == 3, "Incorrect read for model->K");
+
+ mu_assert(matrix_get(model->V, model->K-1, 0, 0) == 0.1234,
+ "Incorrect model->V element at 0, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 0, 1) == -0.4321,
+ "Incorrect model->V element at 0, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 1, 0) == -0.5678,
+ "Incorrect model->V element at 1, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 1, 1) == 0.9876,
+ "Incorrect model->V element at 1, 1");
+ mu_assert(matrix_get(model->V, model->K-1, 2, 0) == 0.9012,
+ "Incorrect model->V element at 2, 0");
+ mu_assert(matrix_get(model->V, model->K-1, 2, 1) == -0.5555,
+ "Incorrect model->V element at 2, 1");
+
+ // end test code //
+
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *test_gensvm_write_model()
+{
+ struct GenModel *model = gensvm_init_model();
+
+ model->p = 0.90328;
+ model->lambda = 0.0130;
+ model->kappa = 1.1832;
+ model->epsilon = 1e-8;
+ model->weight_idx = 1;
+ model->data_file = strdup("./data/test_file_read_data.txt");
+ model->n = 10;
+ model->m = 2;
+ model->K = 3;
+
+ model->V = Calloc(double, (model->m+1)*(model->K-1));
+ matrix_set(model->V, model->K-1, 0, 0, 0.4989893785603748);
+ matrix_set(model->V, model->K-1, 0, 1, 0.0599082796573645);
+ matrix_set(model->V, model->K-1, 1, 0, 0.7918204761759593);
+ matrix_set(model->V, model->K-1, 1, 1, 0.6456613497110559);
+ matrix_set(model->V, model->K-1, 2, 0, 0.9711956316284261);
+ matrix_set(model->V, model->K-1, 2, 1, 0.5010714686310176);
+
+ // start test code //
+ gensvm_write_model(model, "./data/test_write_model.txt");
+
+ FILE *fid = fopen("./data/test_write_model.txt", "r");
+ mu_assert(fid != NULL, "Couldn't open output file for reading");
+
+ char buffer[MAX_LINE_LENGTH];
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "Output file for GenSVM (version 0.1)\n")
+ == 0, "Line doesn't contain expected content (0).\n");
+
+ // skip the time line
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "\n") == 0,
+ "Line doesn't contain expected content (1).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "Model:\n") == 0,
+ "Line doesn't contain expected content (2).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "p = 0.9032800000000000\n") == 0,
+ "Line doesn't contain expected content (3).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "lambda = 0.0130000000000000\n") == 0,
+ "Line doesn't contain expected content (4).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "kappa = 1.1832000000000000\n") == 0,
+ "Line doesn't contain expected content (5).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "epsilon = 1e-08\n") == 0,
+ "Line doesn't contain expected content (6).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "weight_idx = 1\n") == 0,
+ "Line doesn't contain expected content (7).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "\n") == 0,
+ "Line doesn't contain expected content (8).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "Data:\n") == 0,
+ "Line doesn't contain expected content (9).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "filename = ./data/test_file_read_data.txt\n")
+ == 0, "Line doesn't contain expected content (10).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "n = 10\n") == 0,
+ "Line doesn't contain expected content (11).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "m = 2\n") == 0,
+ "Line doesn't contain expected content (12).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "K = 3\n") == 0,
+ "Line doesn't contain expected content (13).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "\n") == 0,
+ "Line doesn't contain expected content (14).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "Output:\n") == 0,
+ "Line doesn't contain expected content (15).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "+0.4989893785603748 +0.0599082796573645\n")
+ == 0, "Line doesn't contain expected content (16).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "+0.7918204761759593 +0.6456613497110559\n")
+ == 0, "Line doesn't contain expected content (17).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "+0.9711956316284261 +0.5010714686310176\n")
+ == 0, "Line doesn't contain expected content (18).\n");
+
+ fclose(fid);
+
+ // end test code //
+
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *test_gensvm_write_predictions()
+{
+ int n = 5,
+ m = 3,
+ K = 3;
+ struct GenData *data = gensvm_init_data();
+ long *predy = Calloc(long, n);
+
+ data->n = n;
+ data->m = m;
+ data->K = K;
+
+ data->Z = Calloc(double, data->n * (data->m+1));
+
+ matrix_set(data->Z, data->m+1, 0, 0, 1.0);
+ matrix_set(data->Z, data->m+1, 1, 0, 1.0);
+ matrix_set(data->Z, data->m+1, 2, 0, 1.0);
+ matrix_set(data->Z, data->m+1, 3, 0, 1.0);
+ matrix_set(data->Z, data->m+1, 4, 0, 1.0);
+
+ matrix_set(data->Z, data->m+1, 0, 1, 0.7065937536993949);
+ matrix_set(data->Z, data->m+1, 0, 2, 0.7016517970438980);
+ matrix_set(data->Z, data->m+1, 0, 3, 0.1548611397288129);
+ matrix_set(data->Z, data->m+1, 1, 1, 0.4604987687863951);
+ matrix_set(data->Z, data->m+1, 1, 2, 0.6374142980176117);
+ matrix_set(data->Z, data->m+1, 1, 3, 0.0370930278245423);
+ matrix_set(data->Z, data->m+1, 2, 1, 0.3798777132278375);
+ matrix_set(data->Z, data->m+1, 2, 2, 0.5745070018747664);
+ matrix_set(data->Z, data->m+1, 2, 3, 0.2570906697837264);
+ matrix_set(data->Z, data->m+1, 3, 1, 0.2789376050039792);
+ matrix_set(data->Z, data->m+1, 3, 2, 0.4853242744610165);
+ matrix_set(data->Z, data->m+1, 3, 3, 0.1894010436762711);
+ matrix_set(data->Z, data->m+1, 4, 1, 0.7630904372339489);
+ matrix_set(data->Z, data->m+1, 4, 2, 0.1341546320318005);
+ matrix_set(data->Z, data->m+1, 4, 3, 0.6827430912944857);
+
+ predy[0] = 3;
+ predy[1] = 2;
+ predy[2] = 1;
+ predy[3] = 2;
+ predy[4] = 1;
+
+ // start test code //
+
+ gensvm_write_predictions(data, predy,
+ "./data/test_write_predictions.txt");
+
+ FILE *fid = fopen("./data/test_write_predictions.txt", "r");
+ mu_assert(fid != NULL, "Couldn't open output file for reading");
+
+ char buffer[MAX_LINE_LENGTH];
+ // skip the first two lines
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "5\n") == 0,
+ "Line doesn't contain expected content (0).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "3\n") == 0,
+ "Line doesn't contain expected content (1).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "0.7065937536993949 0.7016517970438980 "
+ "0.1548611397288129 3\n") == 0,
+ "Line doesn't contain expected content (2).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "0.4604987687863951 0.6374142980176117 "
+ "0.0370930278245423 2\n") == 0,
+ "Line doesn't contain expected content (3).\n");
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "0.3798777132278375 0.5745070018747664 "
+ "0.2570906697837264 1\n") == 0,
+ "Line doesn't contain expected content (4).\n");
+
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "0.2789376050039792 0.4853242744610165 "
+ "0.1894010436762711 2\n") == 0,
+ "Line doesn't contain expected content (5).\n");
+
+
+ fgets(buffer, MAX_LINE_LENGTH, fid);
+ mu_assert(strcmp(buffer, "0.7630904372339489 0.1341546320318005 "
+ "0.6827430912944857 1\n") == 0,
+ "Line doesn't contain expected content (6).\n");
+
+ fclose(fid);
+ // end test code //
+
+ gensvm_free_data(data);
+ free(predy);
+
+ return NULL;
+}
+
+char *test_gensvm_time_string()
+{
+ // not sure how to unit test this function.
+
+ mu_test_missing();
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_gensvm_read_data);
+ mu_run_test(test_gensvm_read_data_no_label);
+ mu_run_test(test_gensvm_read_model);
+ mu_run_test(test_gensvm_write_model);
+ mu_run_test(test_gensvm_write_predictions);
+ mu_run_test(test_gensvm_time_string);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_optimize.c b/tests/src/test_gensvm_optimize.c
new file mode 100644
index 0000000..dda6d08
--- /dev/null
+++ b/tests/src/test_gensvm_optimize.c
@@ -0,0 +1,450 @@
+/**
+ * @file test_gensvm_optimize.c
+ * @author Gertjan van den Burg
+ * @date September, 2016
+ * @brief Unit tests for gensvm_optimize.c functions
+ */
+
+#include "minunit.h"
+#include "gensvm_optimize.h"
+
+char *test_gensvm_optimize()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *test_gensvm_get_loss()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *test_gensvm_get_update()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *test_gensvm_category_matrix()
+{
+ struct GenModel *model = gensvm_init_model();
+ struct GenData *data = gensvm_init_data();
+
+ int n = 5,
+ m = 3,
+ K = 3;
+
+ data->n = n;
+ data->m = m;
+ data->K = K;
+
+ model->n = n;
+ model->m = m;
+ model->K = K;
+
+ gensvm_allocate_model(model);
+ data->y = Calloc(long, data->n);
+ data->y[0] = 1;
+ data->y[1] = 2;
+ data->y[2] = 3;
+ data->y[3] = 2;
+ data->y[4] = 1;
+
+ // start test code //
+
+ gensvm_category_matrix(model, data);
+
+ mu_assert(matrix_get(model->R, K, 0, 0) == 0, "Incorrect R at 0, 0");
+ mu_assert(matrix_get(model->R, K, 0, 1) == 1, "Incorrect R at 0, 1");
+ mu_assert(matrix_get(model->R, K, 0, 2) == 1, "Incorrect R at 0, 2");
+
+ mu_assert(matrix_get(model->R, K, 1, 0) == 1, "Incorrect R at 1, 0");
+ mu_assert(matrix_get(model->R, K, 1, 1) == 0, "Incorrect R at 1, 1");
+ mu_assert(matrix_get(model->R, K, 1, 2) == 1, "Incorrect R at 1, 2");
+
+ mu_assert(matrix_get(model->R, K, 2, 0) == 1, "Incorrect R at 2, 0");
+ mu_assert(matrix_get(model->R, K, 2, 1) == 1, "Incorrect R at 2, 1");
+ mu_assert(matrix_get(model->R, K, 2, 2) == 0, "Incorrect R at 2, 2");
+
+ mu_assert(matrix_get(model->R, K, 3, 0) == 1, "Incorrect R at 3, 0");
+ mu_assert(matrix_get(model->R, K, 3, 1) == 0, "Incorrect R at 3, 1");
+ mu_assert(matrix_get(model->R, K, 3, 2) == 1, "Incorrect R at 3, 2");
+
+ mu_assert(matrix_get(model->R, K, 4, 0) == 0, "Incorrect R at 4, 0");
+ mu_assert(matrix_get(model->R, K, 4, 1) == 1, "Incorrect R at 4, 1");
+ mu_assert(matrix_get(model->R, K, 4, 2) == 1, "Incorrect R at 4, 2");
+
+ // end test code //
+
+ gensvm_free_model(model);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_gensvm_simplex_diff()
+{
+ struct GenData *data = gensvm_init_data();
+ struct GenModel *model = gensvm_init_model();
+
+ int n = 8,
+ m = 3,
+ K = 3;
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ data->n = n;
+ data->m = m;
+ data->K = K;
+
+ data->y = Calloc(long, n);
+
+ gensvm_allocate_model(model);
+ gensvm_simplex(model->K, model->U);
+
+ data->y[0] = 2;
+ data->y[1] = 1;
+ data->y[2] = 3;
+ data->y[3] = 2;
+ data->y[4] = 3;
+ data->y[5] = 3;
+ data->y[6] = 1;
+ data->y[7] = 2;
+
+ // start test code //
+ gensvm_simplex_diff(model, data);
+
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 0) -
+ 1.0000000000000000) < 1e-14,
+ "Incorrect value at UU(0, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(0, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 2) -
+ 0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(0, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(0, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(0, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 2) -
+ -0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(0, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(1, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 1) -
+ -1.0000000000000000) < 1e-14,
+ "Incorrect value at UU(1, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 2) -
+ -0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(1, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(1, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(1, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 2) -
+ -0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(1, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 0) -
+ 0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(2, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 1) -
+ -0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(2, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 2) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(2, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 0) -
+ 0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(2, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 1) -
+ 0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(2, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 2) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(2, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 0) -
+ 1.0000000000000000) < 1e-14,
+ "Incorrect value at UU(3, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(3, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 2) -
+ 0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(3, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(3, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(3, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 2) -
+ -0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(3, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 0) -
+ 0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(4, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 1) -
+ -0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(4, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 2) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(4, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 0) -
+ 0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(4, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 1) -
+ 0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(4, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 2) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(4, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 0) -
+ 0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(5, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 1) -
+ -0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(5, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 2) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(5, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 0) -
+ 0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(5, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 1) -
+ 0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(5, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 2) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(5, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(6, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 1) -
+ -1.0000000000000000) < 1e-14,
+ "Incorrect value at UU(6, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 2) -
+ -0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(6, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(6, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(6, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 2) -
+ -0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(6, 1, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 0) -
+ 1.0000000000000000) < 1e-14,
+ "Incorrect value at UU(7, 0, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(7, 0, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 2) -
+ 0.5000000000000000) < 1e-14,
+ "Incorrect value at UU(7, 0, 2)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(7, 1, 0)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect value at UU(7, 1, 1)");
+ mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 2) -
+ -0.8660254037844388) < 1e-14,
+ "Incorrect value at UU(7, 1, 2)");
+ // end test code //
+
+ gensvm_free_model(model);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *test_gensvm_calculate_errors()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *test_gensvm_calculate_huber()
+{
+ struct GenModel *model = gensvm_init_model();
+ model->n = 5;
+ model->m = 3;
+ model->K = 3;
+ model->kappa = 0.5;
+
+ gensvm_allocate_model(model);
+
+ matrix_set(model->Q, model->K, 0, 0, -0.3386242674244120);
+ matrix_set(model->Q, model->K, 0, 1, 1.0828252163937386);
+ matrix_set(model->Q, model->K, 0, 2, 0.9734009993634181);
+ matrix_set(model->Q, model->K, 1, 0, 1.3744927461858576);
+ matrix_set(model->Q, model->K, 1, 1, 1.8086820272988162);
+ matrix_set(model->Q, model->K, 1, 2, 0.9587412628706828);
+ matrix_set(model->Q, model->K, 2, 0, -0.0530412768492290);
+ matrix_set(model->Q, model->K, 2, 1, 0.4026826962708268);
+ matrix_set(model->Q, model->K, 2, 2, -1.9705914880746673);
+ matrix_set(model->Q, model->K, 3, 0, -0.8749982403551375);
+ matrix_set(model->Q, model->K, 3, 1, 1.3981525936474806);
+ matrix_set(model->Q, model->K, 3, 2, 0.5845478158465323);
+ matrix_set(model->Q, model->K, 4, 0, 0.9594104113136890);
+ matrix_set(model->Q, model->K, 4, 1, -0.7058945833639207);
+ matrix_set(model->Q, model->K, 4, 2, -1.8413342248272893);
+
+ // start test code //
+ gensvm_calculate_huber(model);
+
+ mu_assert(fabs(matrix_get(model->H, model->K, 0, 0) -
+ 0.5973049764458478) < 1e-14,
+ "Incorrect H at 0, 0");
+ mu_assert(fabs(matrix_get(model->H, model->K, 0, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect H at 0, 1");
+ mu_assert(fabs(matrix_get(model->H, model->K, 0, 2) -
+ 0.0002358356116216) < 1e-14,
+ "Incorrect H at 0, 2");
+ mu_assert(fabs(matrix_get(model->H, model->K, 1, 0) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect H at 1, 0");
+ mu_assert(fabs(matrix_get(model->H, model->K, 1, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect H at 1, 1");
+ mu_assert(fabs(matrix_get(model->H, model->K, 1, 2) -
+ 0.0005674277965020) < 1e-14,
+ "Incorrect H at 1, 2");
+ mu_assert(fabs(matrix_get(model->H, model->K, 2, 0) -
+ 0.3696319769160848) < 1e-14,
+ "Incorrect H at 2, 0");
+ mu_assert(fabs(matrix_get(model->H, model->K, 2, 1) -
+ 0.1189293204447631) < 1e-14,
+ "Incorrect H at 2, 1");
+ mu_assert(fabs(matrix_get(model->H, model->K, 2, 2) -
+ 2.2205914880746676) < 1e-14,
+ "Incorrect H at 2, 2");
+ mu_assert(fabs(matrix_get(model->H, model->K, 3, 0) -
+ 1.1249982403551375) < 1e-14,
+ "Incorrect H at 3, 0");
+ mu_assert(fabs(matrix_get(model->H, model->K, 3, 1) -
+ 0.0000000000000000) < 1e-14,
+ "Incorrect H at 3, 1");
+ mu_assert(fabs(matrix_get(model->H, model->K, 3, 2) -
+ 0.0575335057726290) < 1e-14,
+ "Incorrect H at 3, 2");
+ mu_assert(fabs(matrix_get(model->H, model->K, 4, 0) -
+ 0.0005491715699080) < 1e-14,
+ "Incorrect H at 4, 0");
+ mu_assert(fabs(matrix_get(model->H, model->K, 4, 1) -
+ 0.9558945833639207) < 1e-14,
+ "Incorrect H at 4, 1");
+ mu_assert(fabs(matrix_get(model->H, model->K, 4, 2) -
+ 2.0913342248272890) < 1e-14,
+ "Incorrect H at 4, 2");
+
+ // start end code //
+
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *test_gensvm_step_doubling()
+{
+ struct GenModel *model = gensvm_init_model();
+ model->n = 5;
+ model->m = 3;
+ model->K = 3;
+
+ gensvm_allocate_model(model);
+
+ // start test code //
+
+ matrix_set(model->V, model->K-1, 0, 0, 0.4534886648299394);
+ matrix_set(model->V, model->K-1, 0, 1, 0.0744278734246461);
+ matrix_set(model->V, model->K-1, 1, 0, 0.3119251404109698);
+ matrix_set(model->V, model->K-1, 1, 1, 0.4656439597720683);
+ matrix_set(model->V, model->K-1, 2, 0, 0.2011718791962903);
+ matrix_set(model->V, model->K-1, 2, 1, 0.6500799120482493);
+ matrix_set(model->V, model->K-1, 3, 0, 0.4203721613186416);
+ matrix_set(model->V, model->K-1, 3, 1, 0.3322561487796912);
+
+ matrix_set(model->Vbar, model->K-1, 0, 0, 0.4716285304362394);
+ matrix_set(model->Vbar, model->K-1, 0, 1, 0.9580963971307287);
+ matrix_set(model->Vbar, model->K-1, 1, 0, 0.4665492786460124);
+ matrix_set(model->Vbar, model->K-1, 1, 1, 0.7584128769293659);
+ matrix_set(model->Vbar, model->K-1, 2, 0, 0.0694310200377222);
+ matrix_set(model->Vbar, model->K-1, 2, 1, 0.9819492320913891);
+ matrix_set(model->Vbar, model->K-1, 3, 0, 0.9308026068356582);
+ matrix_set(model->Vbar, model->K-1, 3, 1, 0.1323286413371241);
+
+ gensvm_step_doubling(model);
+
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 0) -
+ 0.4353487992236394) < 1e-14,
+ "Incorrect V at 0, 0");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 1) -
+ -0.8092406502814364) < 1e-14,
+ "Incorrect V at 0, 1");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 1, 0) -
+ 0.1573010021759272) < 1e-14,
+ "Incorrect V at 1, 0");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 1, 1) -
+ 0.1728750426147708) < 1e-14,
+ "Incorrect V at 1, 1");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 2, 0) -
+ 0.3329127383548583) < 1e-14,
+ "Incorrect V at 2, 0");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 2, 1) -
+ 0.3182105920051096) < 1e-14,
+ "Incorrect V at 2, 1");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 0) -
+ -0.0900582841983750) < 1e-14,
+ "Incorrect V at 3, 0");
+ mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 1) -
+ 0.5321836562222582) < 1e-14,
+ "Incorrect V at 3, 1");
+
+ // end test code //
+
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *test_dposv()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *test_dsysv()
+{
+ mu_test_missing();
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_gensvm_optimize);
+ mu_run_test(test_gensvm_get_loss);
+ mu_run_test(test_gensvm_get_update);
+ mu_run_test(test_gensvm_category_matrix);
+ mu_run_test(test_gensvm_simplex_diff);
+ mu_run_test(test_gensvm_calculate_errors);
+ mu_run_test(test_gensvm_calculate_huber);
+ mu_run_test(test_gensvm_step_doubling);
+ mu_run_test(test_dposv);
+ mu_run_test(test_dsysv);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_pred.c b/tests/src/test_gensvm_pred.c
new file mode 100644
index 0000000..13f0e5a
--- /dev/null
+++ b/tests/src/test_gensvm_pred.c
@@ -0,0 +1,165 @@
+/**
+ * @file test_gensvm_pred.c
+ * @author Gertjan van den Burg
+ * @date September, 2016
+ * @brief Unit tests for gensvm_pred.c functions
+ */
+#include "minunit.h"
+#include "gensvm_pred.h"
+
+/**
+ * This testcase is designed as follows: 12 evenly spaced points are plotted
+ * on the unit circle in the simplex space. These points are the ones for
+ * which we want to predict the class. To get these points, we need a Z and a
+ * V which map to these points. To get this, Z was equal to [1 Q] and V was
+ * equal to [0; R] where Q and R are from the reduced QR decomposition of the
+ * 12x2 matrix S which contains the points in simplex space. Here's the
+ * Matlab/Octave code to generate this data:
+ *
+ * n = 12;
+ * K = 3;
+ * S = [cos(1/12*pi+1/6*pi*[0:(n-1)])', sin(1/12*pi+1/6*pi*[0:(n-1)])'];
+ * [Q, R] = qr(S, '0');
+ * Z = [ones(n, 1), Q];
+ * V = [zeros(1, K-1); R];
+ *
+ */
+char *test_gensvm_predict_labels()
+{
+ int n = 12;
+ int m = 2;
+ int K = 3;
+
+ struct GenData *data = gensvm_init_data();
+ struct GenModel *model = gensvm_init_model();
+
+ model->n = n;
+ model->m = m;
+ model->K = K;
+
+ data->n = n;
+ data->m = m;
+ data->r = m;
+ data->K = K;
+
+ data->Z = Calloc(double, n*(m+1));
+ data->y = Calloc(long, n);
+
+ matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065);
+ matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935);
+ matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129);
+ matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128);
+ matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937);
+ matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063);
+ matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935);
+ matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064);
+ matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129);
+ matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129);
+ matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064);
+ matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937);
+ matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065);
+ matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935);
+ matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130);
+ matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128);
+ matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939);
+ matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063);
+ matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934);
+ matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064);
+ matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126);
+ matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132);
+ matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064);
+ matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939);
+
+ gensvm_allocate_model(model);
+
+ matrix_set(model->V, K-1, 0, 0, 0.0000000000000000);
+ matrix_set(model->V, K-1, 0, 1, 0.0000000000000000);
+ matrix_set(model->V, K-1, 1, 0, -2.4494897427831779);
+ matrix_set(model->V, K-1, 1, 1, -0.0000000000000002);
+ matrix_set(model->V, K-1, 2, 0, 0.0000000000000000);
+ matrix_set(model->V, K-1, 2, 1, -2.4494897427831783);
+
+ // start test code
+ long *predy = Calloc(long, n);
+ gensvm_predict_labels(data, model, predy);
+
+ mu_assert(predy[0] == 2, "Incorrect label at index 0");
+ mu_assert(predy[1] == 3, "Incorrect label at index 1");
+ mu_assert(predy[2] == 3, "Incorrect label at index 2");
+ mu_assert(predy[3] == 3, "Incorrect label at index 3");
+ mu_assert(predy[4] == 3, "Incorrect label at index 4");
+ mu_assert(predy[5] == 1, "Incorrect label at index 5");
+ mu_assert(predy[6] == 1, "Incorrect label at index 6");
+ mu_assert(predy[7] == 1, "Incorrect label at index 7");
+ mu_assert(predy[8] == 1, "Incorrect label at index 8");
+ mu_assert(predy[9] == 2, "Incorrect label at index 9");
+ mu_assert(predy[10] == 2, "Incorrect label at index 10");
+ mu_assert(predy[11] == 2, "Incorrect label at index 11");
+
+ // end test code
+ gensvm_free_data(data);
+ gensvm_free_model(model);
+ free(predy);
+
+ return NULL;
+}
+
+char *test_gensvm_prediction_perf()
+{
+ int i, n = 8;
+ struct GenData *data = gensvm_init_data();
+ data->n = n;
+ data->y = Calloc(long, n);
+ data->y[0] = 1;
+ data->y[1] = 1;
+ data->y[2] = 1;
+ data->y[3] = 1;
+ data->y[4] = 2;
+ data->y[5] = 2;
+ data->y[6] = 2;
+ data->y[7] = 3;
+
+ long *y = Calloc(long, n);
+ for (i=0; i<n; i++)
+ y[i] = 1;
+ mu_assert(gensvm_prediction_perf(data, y) == 50.0,
+ "Incorrect first time.");
+
+ for (i=0; i<n; i++)
+ y[i] = 2;
+ mu_assert(gensvm_prediction_perf(data, y) == 37.5,
+ "Incorrect second time.");
+
+ for (i=0; i<n; i++)
+ y[i] = 3;
+ mu_assert(gensvm_prediction_perf(data, y) == 12.5,
+ "Incorrect third time.");
+
+ free(y);
+ gensvm_free_data(data);
+
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_gensvm_predict_labels);
+ mu_run_test(test_gensvm_prediction_perf);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);
diff --git a/tests/src/test_gensvm_sv.c b/tests/src/test_gensvm_sv.c
new file mode 100644
index 0000000..f6ea0e2
--- /dev/null
+++ b/tests/src/test_gensvm_sv.c
@@ -0,0 +1,62 @@
+/**
+ * @file test_gensvm_sv.c
+ * @author Gertjan van den Burg
+ * @date May, 2016
+ * @brief Unit tests for gensvm_sv.c functions
+ */
+
+#include "minunit.h"
+#include "gensvm_sv.h"
+
+char *test_sv()
+{
+ struct GenModel *model = gensvm_init_model();
+
+ model->n = 5;
+ model->m = 3;
+ model->K = 3;
+ gensvm_allocate_model(model);
+
+ // for a support vector we need less than 2 elements per row larger
+ // than 1
+
+ // this is an sv
+ matrix_set(model->Q, model->K, 0, 0, 1.1);
+ matrix_set(model->Q, model->K, 0, 1, 0.0);
+ matrix_set(model->Q, model->K, 0, 2, 1.0);
+
+ // this is an sv
+ matrix_set(model->Q, model->K, 1, 0, 0.5);
+ matrix_set(model->Q, model->K, 1, 1, 0.5);
+ matrix_set(model->Q, model->K, 1, 2, 0.5);
+
+ // this is an sv
+ matrix_set(model->Q, model->K, 2, 0, -0.5);
+ matrix_set(model->Q, model->K, 2, 1, 0.5);
+ matrix_set(model->Q, model->K, 2, 2, -0.5);
+
+ // this is not an sv
+ matrix_set(model->Q, model->K, 3, 0, 1.5);
+ matrix_set(model->Q, model->K, 3, 1, 1.5);
+ matrix_set(model->Q, model->K, 3, 2, 0.5);
+
+ // this is not an sv
+ matrix_set(model->Q, model->K, 4, 0, 2.0);
+ matrix_set(model->Q, model->K, 4, 1, 2.0);
+ matrix_set(model->Q, model->K, 4, 2, 2.0);
+
+ mu_assert(gensvm_num_sv(model) == 3, "number of svs incorrect");
+
+ gensvm_free_model(model);
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_sv);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);