/** * @file test_gensvm_predict.c * @author G.J.J. van den Burg * @date 2016-09-01 * @brief Unit tests for gensvm_predict.c functions * * @copyright Copyright 2016, G.J.J. van den Burg. This file is part of GenSVM. GenSVM is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. GenSVM is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with GenSVM. If not, see <http://www.gnu.org/licenses/>. */ #include "minunit.h" #include "gensvm_predict.h" /** * This testcase is designed as follows: 12 evenly spaced points are plotted * on the unit circle in the simplex space. These are the points for which * we want to predict the class. To obtain them as predictions, Z and V are * chosen such that ZV = S: Z is equal to [1 Q] and V is equal to [0; R], * where Q and R are from the reduced QR decomposition of the 12x2 matrix S * which contains the points in simplex space.
Here's the * Matlab/Octave code to generate this data: * * n = 12; * K = 3; * S = [cos(1/12*pi+1/6*pi*[0:(n-1)])', sin(1/12*pi+1/6*pi*[0:(n-1)])']; * [Q, R] = qr(S, '0'); * Z = [ones(n, 1), Q]; * V = [zeros(1, K-1); R]; * */ char *test_gensvm_predict_labels_dense() { int n = 12; int m = 2; int K = 3; struct GenData *data = gensvm_init_data(); struct GenModel *model = gensvm_init_model(); model->n = n; model->m = m; model->K = K; data->n = n; data->m = m; data->r = m; data->K = K; data->Z = Calloc(double, n*(m+1)); data->y = Calloc(long, n); matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); matrix_set(data->Z, m+1, 9, 0, 
1.0000000000000000); matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); gensvm_allocate_model(model); matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); // start test code long *predy = Calloc(long, n); gensvm_predict_labels(data, model, predy); mu_assert(predy[0] == 2, "Incorrect label at index 0"); mu_assert(predy[1] == 3, "Incorrect label at index 1"); mu_assert(predy[2] == 3, "Incorrect label at index 2"); mu_assert(predy[3] == 3, "Incorrect label at index 3"); mu_assert(predy[4] == 3, "Incorrect label at index 4"); mu_assert(predy[5] == 1, "Incorrect label at index 5"); mu_assert(predy[6] == 1, "Incorrect label at index 6"); mu_assert(predy[7] == 1, "Incorrect label at index 7"); mu_assert(predy[8] == 1, "Incorrect label at index 8"); mu_assert(predy[9] == 2, "Incorrect label at index 9"); mu_assert(predy[10] == 2, "Incorrect label at index 10"); mu_assert(predy[11] == 2, "Incorrect label at index 11"); // end test code gensvm_free_data(data); gensvm_free_model(model); free(predy); return NULL; } char *test_gensvm_predict_labels_sparse() { int n = 12; int m = 2; int K = 3; struct GenData *data = gensvm_init_data(); struct GenModel *model = gensvm_init_model(); model->n = n; model->m = m; model->K = K; data->n = n; data->m = m; data->r = m; data->K = K; data->Z = Calloc(double, n*(m+1)); data->y = Calloc(long, n); 
matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); // convert Z to a sparse matrix to test the sparse functions data->spZ = gensvm_dense_to_sparse(data->Z, 
data->n, data->m+1); free(data->Z); data->RAW = NULL; data->Z = NULL; gensvm_allocate_model(model); matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); // start test code long *predy = Calloc(long, n); gensvm_predict_labels(data, model, predy); mu_assert(predy[0] == 2, "Incorrect label at index 0"); mu_assert(predy[1] == 3, "Incorrect label at index 1"); mu_assert(predy[2] == 3, "Incorrect label at index 2"); mu_assert(predy[3] == 3, "Incorrect label at index 3"); mu_assert(predy[4] == 3, "Incorrect label at index 4"); mu_assert(predy[5] == 1, "Incorrect label at index 5"); mu_assert(predy[6] == 1, "Incorrect label at index 6"); mu_assert(predy[7] == 1, "Incorrect label at index 7"); mu_assert(predy[8] == 1, "Incorrect label at index 8"); mu_assert(predy[9] == 2, "Incorrect label at index 9"); mu_assert(predy[10] == 2, "Incorrect label at index 10"); mu_assert(predy[11] == 2, "Incorrect label at index 11"); // end test code gensvm_free_data(data); gensvm_free_model(model); free(predy); return NULL; } char *test_gensvm_prediction_perf() { int i, n = 8; struct GenData *data = gensvm_init_data(); data->n = n; data->y = Calloc(long, n); data->y[0] = 1; data->y[1] = 1; data->y[2] = 1; data->y[3] = 1; data->y[4] = 2; data->y[5] = 2; data->y[6] = 2; data->y[7] = 3; long *y = Calloc(long, n); for (i=0; i