diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2016-09-20 16:46:41 +0200 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2016-09-20 16:46:41 +0200 |
| commit | 28497ecc371fd4c731892de91b07f635f1020452 (patch) | |
| tree | 6c1876f3e64e7d0d310f29d4f20fbf69d96f3ef4 /tests/src/test_gensvm_pred.c | |
| parent | Minor improvements and fixes (diff) | |
| download | gensvm-28497ecc371fd4c731892de91b07f635f1020452.tar.gz gensvm-28497ecc371fd4c731892de91b07f635f1020452.zip | |
Unit tests and corresponding data
Diffstat (limited to 'tests/src/test_gensvm_pred.c')
| -rw-r--r-- | tests/src/test_gensvm_pred.c | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/tests/src/test_gensvm_pred.c b/tests/src/test_gensvm_pred.c new file mode 100644 index 0000000..13f0e5a --- /dev/null +++ b/tests/src/test_gensvm_pred.c @@ -0,0 +1,165 @@ +/** + * @file test_gensvm_pred.c + * @author Gertjan van den Burg + * @date September, 2016 + * @brief Unit tests for gensvm_pred.c functions + */ +#include "minunit.h" +#include "gensvm_pred.h" + +/** + * This testcase is designed as follows: 12 evenly spaced points are plotted + * on the unit circle in the simplex space. These points are the ones for + * which we want to predict the class. To get these points, we need a Z and a + * V which map to these points. To get this, Z was equal to [1 Q] and V was + * equal to [0; R] where Q and R are from the reduced QR decomposition of the + * 12x2 matrix S which contains the points in simplex space. Here's the + * Matlab/Octave code to generate this data: + * + * n = 12; + * K = 3; + * S = [cos(1/12*pi+1/6*pi*[0:(n-1)])', sin(1/12*pi+1/6*pi*[0:(n-1)])']; + * [Q, R] = qr(S, '0'); + * Z = [ones(n, 1), Q]; + * V = [zeros(1, K-1); R]; + * + */ +char *test_gensvm_predict_labels() +{ + int n = 12; + int m = 2; + int K = 3; + + struct GenData *data = gensvm_init_data(); + struct GenModel *model = gensvm_init_model(); + + model->n = n; + model->m = m; + model->K = K; + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + + data->Z = Calloc(double, n*(m+1)); + data->y = Calloc(long, n); + + matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); + matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); + matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); + matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); + matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); + matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); + matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); + matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); + matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); + matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); + matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); + matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); + matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); + matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); + matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); + matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); + matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); + matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); + matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); + matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); + matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); + matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); + matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); + matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); + + gensvm_allocate_model(model); + + matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); + matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); + matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); + matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); + + // start test code + long *predy = Calloc(long, n); + gensvm_predict_labels(data, model, predy); + + mu_assert(predy[0] == 2, "Incorrect label at index 0"); + mu_assert(predy[1] == 3, "Incorrect label at index 1"); + mu_assert(predy[2] == 3, "Incorrect label at index 2"); + mu_assert(predy[3] == 3, "Incorrect label at index 3"); + mu_assert(predy[4] == 3, "Incorrect label at index 4"); + mu_assert(predy[5] == 1, "Incorrect label at index 5"); + mu_assert(predy[6] == 1, "Incorrect label at index 6"); + mu_assert(predy[7] == 1, "Incorrect label at index 7"); + mu_assert(predy[8] == 1, "Incorrect label at index 8"); + mu_assert(predy[9] == 2, "Incorrect label at index 9"); + mu_assert(predy[10] == 2, "Incorrect label at index 10"); + mu_assert(predy[11] == 2, "Incorrect label at index 11"); + + // end test code + gensvm_free_data(data); + gensvm_free_model(model); + free(predy); + + return NULL; +} + +char *test_gensvm_prediction_perf() +{ + int i, n = 8; + struct GenData *data = gensvm_init_data(); + data->n = n; + data->y = Calloc(long, n); + data->y[0] = 1; + data->y[1] = 1; + data->y[2] = 1; + data->y[3] = 1; + data->y[4] = 2; + data->y[5] = 2; + data->y[6] = 2; + data->y[7] = 3; + + long *y = Calloc(long, n); + for (i=0; i<n; i++) + y[i] = 1; + mu_assert(gensvm_prediction_perf(data, y) == 50.0, + "Incorrect first time."); + + for (i=0; i<n; i++) + y[i] = 2; + mu_assert(gensvm_prediction_perf(data, y) == 37.5, + "Incorrect second time."); + + for (i=0; i<n; i++) + y[i] = 3; + mu_assert(gensvm_prediction_perf(data, y) == 12.5, + "Incorrect third time."); + + free(y); + gensvm_free_data(data); + + return NULL; +} + +char *all_tests() +{ + mu_suite_start(); + mu_run_test(test_gensvm_predict_labels); + mu_run_test(test_gensvm_prediction_perf); + + return NULL; +} + +RUN_TESTS(all_tests); |
