From cc40bb91ce4509a177a729e58cb78afe5ca0dfb0 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Mon, 17 Oct 2016 13:44:54 +0200 Subject: refactor gensvm_pred to gensvm_predict --- include/gensvm_gridsearch.h | 2 +- include/gensvm_pred.h | 29 ----- include/gensvm_predict.h | 29 +++++ src/GenSVMtraintest.c | 2 +- src/gensvm_pred.c | 99 --------------- src/gensvm_predict.c | 99 +++++++++++++++ tests/src/test_gensvm_pred.c | 264 ---------------------------------------- tests/src/test_gensvm_predict.c | 264 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 394 insertions(+), 394 deletions(-) delete mode 100644 include/gensvm_pred.h create mode 100644 include/gensvm_predict.h delete mode 100644 src/gensvm_pred.c create mode 100644 src/gensvm_predict.c delete mode 100644 tests/src/test_gensvm_pred.c create mode 100644 tests/src/test_gensvm_predict.c diff --git a/include/gensvm_gridsearch.h b/include/gensvm_gridsearch.h index dcd9b93..f59f520 100644 --- a/include/gensvm_gridsearch.h +++ b/include/gensvm_gridsearch.h @@ -20,7 +20,7 @@ #include "gensvm_init.h" #include "gensvm_grid.h" #include "gensvm_optimize.h" -#include "gensvm_pred.h" +#include "gensvm_predict.h" #include "gensvm_queue.h" #include "gensvm_timer.h" diff --git a/include/gensvm_pred.h b/include/gensvm_pred.h deleted file mode 100644 index 12a59eb..0000000 --- a/include/gensvm_pred.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * @file gensvm_pred.h - * @author Gertjan van den Burg - * @date August, 2013 - * @brief Header file for gensvm_pred.c - * - * @details - * Contains function declarations for prediction functions. - * - */ - -#ifndef GENSVM_PRED_H -#define GENSVM_PRED_H - -// includes -#include "gensvm_kernel.h" -#include "gensvm_simplex.h" -#include "gensvm_zv.h" - -// function declarations -void gensvm_predict_labels(struct GenData *testdata, - struct GenModel *model, long *predy); -void gensvm_predict_labels_dense(struct GenData *testdata, - struct GenModel *model, long *predy); -void gensvm_predict_labels_sparse(struct GenData *testdata, - struct GenModel *model, long *predy); -double gensvm_prediction_perf(struct GenData *data, long *perdy); - -#endif diff --git a/include/gensvm_predict.h b/include/gensvm_predict.h new file mode 100644 index 0000000..a60ddd9 --- /dev/null +++ b/include/gensvm_predict.h @@ -0,0 +1,29 @@ +/** + * @file gensvm_pred.h + * @author Gertjan van den Burg + * @date August, 2013 + * @brief Header file for gensvm_pred.c + * + * @details + * Contains function declarations for prediction functions. + * + */ + +#ifndef GENSVM_PREDICT_H +#define GENSVM_PREDICT_H + +// includes +#include "gensvm_kernel.h" +#include "gensvm_simplex.h" +#include "gensvm_zv.h" + +// function declarations +void gensvm_predict_labels(struct GenData *testdata, + struct GenModel *model, long *predy); +void gensvm_predict_labels_dense(struct GenData *testdata, + struct GenModel *model, long *predy); +void gensvm_predict_labels_sparse(struct GenData *testdata, + struct GenModel *model, long *predy); +double gensvm_prediction_perf(struct GenData *data, long *perdy); + +#endif diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index f70fa97..777d8e5 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -13,7 +13,7 @@ #include "gensvm_cmdarg.h" #include "gensvm_io.h" #include "gensvm_train.h" -#include "gensvm_pred.h" +#include "gensvm_predict.h" #define MINARGS 2 diff --git a/src/gensvm_pred.c b/src/gensvm_pred.c deleted file mode 100644 index 31b591c..0000000 --- a/src/gensvm_pred.c +++ /dev/null @@ -1,99 +0,0 @@ -/** - * @file gensvm_pred.c - * @author Gertjan van den Burg - * @date August 9, 2013 - * @brief Main functions for predicting class labels.. - * - * @details - * This file contains functions for predicting the class labels of instances - * and a function for calculating the predictive performance (hitrate) of - * a prediction given true class labels. - * - */ - -#include "gensvm_pred.h" - -/** - * @brief Predict class labels of data given and output in predy - * - * @details - * The labels are predicted by mapping each instance in data to the - * simplex space using the matrix V in the given model. Next, for each - * instance the nearest simplex vertex is determined using an Euclidean - * norm. The nearest simplex vertex determines the predicted class label, - * which is recorded in predy. - * - * @param[in] testdata GenData to predict labels for - * @param[in] model GenModel with optimized V - * @param[out] predy pre-allocated vector to record predictions in - */ -void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, - long *predy) -{ - long i, j, k, n, K, label; - double norm, min_dist, - *S = NULL, - *ZV = NULL; - - n = testdata->n; - K = model->K; - - // allocate necessary memory - S = Calloc(double, K-1); - ZV = Calloc(double, n*(K-1)); - - // Generate the simplex matrix - gensvm_simplex(model); - - // Generate the simplex space vectors - gensvm_calculate_ZV(model, testdata, ZV); - - // Calculate the distance to each of the vertices of the simplex. - // The closest vertex defines the class label - for (i=0; iU, K-1, j, k); - } - norm = cblas_dnrm2(K-1, S, 1); - if (norm < min_dist) { - label = j+1; - min_dist = norm; - } - } - predy[i] = label; - } - - free(ZV); - free(S); -} - -/** - * @brief Calculate the predictive performance (percentage correct) - * - * @details - * The predictive performance is calculated by simply counting the number - * of correctly classified samples and dividing by the total number of - * samples, multiplying by 100. - * - * @param[in] data the GenData dataset with known labels - * @param[in] predy the predicted class labels - * - * @returns percentage correctly classified. - */ -double gensvm_prediction_perf(struct GenData *data, long *predy) -{ - long i, correct = 0; - double performance; - - for (i=0; in; i++) - if (data->y[i] == predy[i]) - correct++; - - performance = ((double) correct)/((double) data->n)* 100.0; - - return performance; -} diff --git a/src/gensvm_predict.c b/src/gensvm_predict.c new file mode 100644 index 0000000..1112e55 --- /dev/null +++ b/src/gensvm_predict.c @@ -0,0 +1,99 @@ +/** + * @file gensvm_pred.c + * @author Gertjan van den Burg + * @date August 9, 2013 + * @brief Main functions for predicting class labels.. + * + * @details + * This file contains functions for predicting the class labels of instances + * and a function for calculating the predictive performance (hitrate) of + * a prediction given true class labels. + * + */ + +#include "gensvm_predict.h" + +/** + * @brief Predict class labels of data given and output in predy + * + * @details + * The labels are predicted by mapping each instance in data to the + * simplex space using the matrix V in the given model. Next, for each + * instance the nearest simplex vertex is determined using an Euclidean + * norm. The nearest simplex vertex determines the predicted class label, + * which is recorded in predy. + * + * @param[in] testdata GenData to predict labels for + * @param[in] model GenModel with optimized V + * @param[out] predy pre-allocated vector to record predictions in + */ +void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, + long *predy) +{ + long i, j, k, n, K, label; + double norm, min_dist, + *S = NULL, + *ZV = NULL; + + n = testdata->n; + K = model->K; + + // allocate necessary memory + S = Calloc(double, K-1); + ZV = Calloc(double, n*(K-1)); + + // Generate the simplex matrix + gensvm_simplex(model); + + // Generate the simplex space vectors + gensvm_calculate_ZV(model, testdata, ZV); + + // Calculate the distance to each of the vertices of the simplex. + // The closest vertex defines the class label + for (i=0; iU, K-1, j, k); + } + norm = cblas_dnrm2(K-1, S, 1); + if (norm < min_dist) { + label = j+1; + min_dist = norm; + } + } + predy[i] = label; + } + + free(ZV); + free(S); +} + +/** + * @brief Calculate the predictive performance (percentage correct) + * + * @details + * The predictive performance is calculated by simply counting the number + * of correctly classified samples and dividing by the total number of + * samples, multiplying by 100. + * + * @param[in] data the GenData dataset with known labels + * @param[in] predy the predicted class labels + * + * @returns percentage correctly classified. + */ +double gensvm_prediction_perf(struct GenData *data, long *predy) +{ + long i, correct = 0; + double performance; + + for (i=0; in; i++) + if (data->y[i] == predy[i]) + correct++; + + performance = ((double) correct)/((double) data->n)* 100.0; + + return performance; +} diff --git a/tests/src/test_gensvm_pred.c b/tests/src/test_gensvm_pred.c deleted file mode 100644 index 155f0bf..0000000 --- a/tests/src/test_gensvm_pred.c +++ /dev/null @@ -1,264 +0,0 @@ -/** - * @file test_gensvm_pred.c - * @author Gertjan van den Burg - * @date September, 2016 - * @brief Unit tests for gensvm_pred.c functions - */ -#include "minunit.h" -#include "gensvm_pred.h" - -/** - * This testcase is designed as follows: 12 evenly spaced points are plotted - * on the unit circle in the simplex space. These points are the ones for - * which we want to predict the class. To get these points, we need a Z and a - * V which map to these points. To get this, Z was equal to [1 Q] and V was - * equal to [0; R] where Q and R are from the reduced QR decomposition of the - * 12x2 matrix S which contains the points in simplex space. Here's the - * Matlab/Octave code to generate this data: - * - * n = 12; - * K = 3; - * S = [cos(1/12*pi+1/6*pi*[0:(n-1)])', sin(1/12*pi+1/6*pi*[0:(n-1)])']; - * [Q, R] = qr(S, '0'); - * Z = [ones(n, 1), Q]; - * V = [zeros(1, K-1); R]; - * - */ -char *test_gensvm_predict_labels_dense() -{ - int n = 12; - int m = 2; - int K = 3; - - struct GenData *data = gensvm_init_data(); - struct GenModel *model = gensvm_init_model(); - - model->n = n; - model->m = m; - model->K = K; - - data->n = n; - data->m = m; - data->r = m; - data->K = K; - - data->Z = Calloc(double, n*(m+1)); - data->y = Calloc(long, n); - - matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); - matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); - matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); - matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); - matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); - matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); - matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); - matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); - matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); - matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); - matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); - matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); - matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); - matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); - matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); - matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); - matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); - matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); - matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); - matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); - matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); - matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); - matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); - matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); - - gensvm_allocate_model(model); - - matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); - matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); - matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); - matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); - matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); - matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); - - // start test code - long *predy = Calloc(long, n); - gensvm_predict_labels(data, model, predy); - - mu_assert(predy[0] == 2, "Incorrect label at index 0"); - mu_assert(predy[1] == 3, "Incorrect label at index 1"); - mu_assert(predy[2] == 3, "Incorrect label at index 2"); - mu_assert(predy[3] == 3, "Incorrect label at index 3"); - mu_assert(predy[4] == 3, "Incorrect label at index 4"); - mu_assert(predy[5] == 1, "Incorrect label at index 5"); - mu_assert(predy[6] == 1, "Incorrect label at index 6"); - mu_assert(predy[7] == 1, "Incorrect label at index 7"); - mu_assert(predy[8] == 1, "Incorrect label at index 8"); - mu_assert(predy[9] == 2, "Incorrect label at index 9"); - mu_assert(predy[10] == 2, "Incorrect label at index 10"); - mu_assert(predy[11] == 2, "Incorrect label at index 11"); - - // end test code - gensvm_free_data(data); - gensvm_free_model(model); - free(predy); - - return NULL; -} - -char *test_gensvm_predict_labels_sparse() -{ - int n = 12; - int m = 2; - int K = 3; - - struct GenData *data = gensvm_init_data(); - struct GenModel *model = gensvm_init_model(); - - model->n = n; - model->m = m; - model->K = K; - - data->n = n; - data->m = m; - data->r = m; - data->K = K; - - data->Z = Calloc(double, n*(m+1)); - data->y = Calloc(long, n); - - matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); - matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); - matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); - matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); - matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); - matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); - matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); - matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); - matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); - matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); - matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); - matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); - matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); - matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); - matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); - matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); - matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); - matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); - matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); - matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); - matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); - matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); - matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); - matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); - matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); - - // convert Z to a sparse matrix to test the sparse functions - data->spZ = gensvm_dense_to_sparse(data->Z, data->n, data->m+1); - free(data->Z); - data->RAW = NULL; - data->Z = NULL; - - gensvm_allocate_model(model); - - matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); - matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); - matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); - matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); - matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); - matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); - - // start test code - long *predy = Calloc(long, n); - gensvm_predict_labels(data, model, predy); - - mu_assert(predy[0] == 2, "Incorrect label at index 0"); - mu_assert(predy[1] == 3, "Incorrect label at index 1"); - mu_assert(predy[2] == 3, "Incorrect label at index 2"); - mu_assert(predy[3] == 3, "Incorrect label at index 3"); - mu_assert(predy[4] == 3, "Incorrect label at index 4"); - mu_assert(predy[5] == 1, "Incorrect label at index 5"); - mu_assert(predy[6] == 1, "Incorrect label at index 6"); - mu_assert(predy[7] == 1, "Incorrect label at index 7"); - mu_assert(predy[8] == 1, "Incorrect label at index 8"); - mu_assert(predy[9] == 2, "Incorrect label at index 9"); - mu_assert(predy[10] == 2, "Incorrect label at index 10"); - mu_assert(predy[11] == 2, "Incorrect label at index 11"); - - // end test code - gensvm_free_data(data); - gensvm_free_model(model); - free(predy); - - return NULL; -} - -char *test_gensvm_prediction_perf() -{ - int i, n = 8; - struct GenData *data = gensvm_init_data(); - data->n = n; - data->y = Calloc(long, n); - data->y[0] = 1; - data->y[1] = 1; - data->y[2] = 1; - data->y[3] = 1; - data->y[4] = 2; - data->y[5] = 2; - data->y[6] = 2; - data->y[7] = 3; - - long *y = Calloc(long, n); - for (i=0; in = n; + model->m = m; + model->K = K; + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + + data->Z = Calloc(double, n*(m+1)); + data->y = Calloc(long, n); + + matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); + matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); + matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); + matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); + matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); + matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); + matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); + matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); + matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); + matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); + matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); + matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); + matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); + matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); + matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); + matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); + matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); + matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); + matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); + matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); + matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); + matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); + matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); + matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); + + gensvm_allocate_model(model); + + matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); + matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); + matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); + matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); + + // start test code + long *predy = Calloc(long, n); + gensvm_predict_labels(data, model, predy); + + mu_assert(predy[0] == 2, "Incorrect label at index 0"); + mu_assert(predy[1] == 3, "Incorrect label at index 1"); + mu_assert(predy[2] == 3, "Incorrect label at index 2"); + mu_assert(predy[3] == 3, "Incorrect label at index 3"); + mu_assert(predy[4] == 3, "Incorrect label at index 4"); + mu_assert(predy[5] == 1, "Incorrect label at index 5"); + mu_assert(predy[6] == 1, "Incorrect label at index 6"); + mu_assert(predy[7] == 1, "Incorrect label at index 7"); + mu_assert(predy[8] == 1, "Incorrect label at index 8"); + mu_assert(predy[9] == 2, "Incorrect label at index 9"); + mu_assert(predy[10] == 2, "Incorrect label at index 10"); + mu_assert(predy[11] == 2, "Incorrect label at index 11"); + + // end test code + gensvm_free_data(data); + gensvm_free_model(model); + free(predy); + + return NULL; +} + +char *test_gensvm_predict_labels_sparse() +{ + int n = 12; + int m = 2; + int K = 3; + + struct GenData *data = gensvm_init_data(); + struct GenModel *model = gensvm_init_model(); + + model->n = n; + model->m = m; + model->K = K; + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + + data->Z = Calloc(double, n*(m+1)); + data->y = Calloc(long, n); + + matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); + matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); + matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); + matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); + matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); + matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); + matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); + matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); + matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); + matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); + matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); + matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); + matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); + matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); + matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); + matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); + matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); + matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); + matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); + matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); + matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); + matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); + matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); + matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); + + // convert Z to a sparse matrix to test the sparse functions + data->spZ = gensvm_dense_to_sparse(data->Z, data->n, data->m+1); + free(data->Z); + data->RAW = NULL; + data->Z = NULL; + + gensvm_allocate_model(model); + + matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); + matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); + matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); + matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); + + // start test code + long *predy = Calloc(long, n); + gensvm_predict_labels(data, model, predy); + + mu_assert(predy[0] == 2, "Incorrect label at index 0"); + mu_assert(predy[1] == 3, "Incorrect label at index 1"); + mu_assert(predy[2] == 3, "Incorrect label at index 2"); + mu_assert(predy[3] == 3, "Incorrect label at index 3"); + mu_assert(predy[4] == 3, "Incorrect label at index 4"); + mu_assert(predy[5] == 1, "Incorrect label at index 5"); + mu_assert(predy[6] == 1, "Incorrect label at index 6"); + mu_assert(predy[7] == 1, "Incorrect label at index 7"); + mu_assert(predy[8] == 1, "Incorrect label at index 8"); + mu_assert(predy[9] == 2, "Incorrect label at index 9"); + mu_assert(predy[10] == 2, "Incorrect label at index 10"); + mu_assert(predy[11] == 2, "Incorrect label at index 11"); + + // end test code + gensvm_free_data(data); + gensvm_free_model(model); + free(predy); + + return NULL; +} + +char *test_gensvm_prediction_perf() +{ + int i, n = 8; + struct GenData *data = gensvm_init_data(); + data->n = n; + data->y = Calloc(long, n); + data->y[0] = 1; + data->y[1] = 1; + data->y[2] = 1; + data->y[3] = 1; + data->y[4] = 2; + data->y[5] = 2; + data->y[6] = 2; + data->y[7] = 3; + + long *y = Calloc(long, n); + for (i=0; i