aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/gensvm_optimize.h1
-rw-r--r--include/gensvm_pred.h5
-rw-r--r--include/gensvm_zv.h32
-rw-r--r--src/gensvm_optimize.c43
-rw-r--r--src/gensvm_pred.c19
-rw-r--r--src/gensvm_zv.c115
-rw-r--r--tests/src/test_gensvm_pred.c103
-rw-r--r--tests/src/test_gensvm_zv.c251
8 files changed, 508 insertions, 61 deletions
diff --git a/include/gensvm_optimize.h b/include/gensvm_optimize.h
index 93f6676..39e17b7 100644
--- a/include/gensvm_optimize.h
+++ b/include/gensvm_optimize.h
@@ -16,6 +16,7 @@
#include "gensvm_sv.h"
#include "gensvm_simplex.h"
#include "gensvm_update.h"
+#include "gensvm_zv.h"
// function declarations
void gensvm_optimize(struct GenModel *model, struct GenData *data);
diff --git a/include/gensvm_pred.h b/include/gensvm_pred.h
index 56e16e8..12a59eb 100644
--- a/include/gensvm_pred.h
+++ b/include/gensvm_pred.h
@@ -15,10 +15,15 @@
// includes
#include "gensvm_kernel.h"
#include "gensvm_simplex.h"
+#include "gensvm_zv.h"
// function declarations
void gensvm_predict_labels(struct GenData *testdata,
struct GenModel *model, long *predy);
+void gensvm_predict_labels_dense(struct GenData *testdata,
+ struct GenModel *model, long *predy);
+void gensvm_predict_labels_sparse(struct GenData *testdata,
+ struct GenModel *model, long *predy);
double gensvm_prediction_perf(struct GenData *data, long *perdy);
#endif
diff --git a/include/gensvm_zv.h b/include/gensvm_zv.h
new file mode 100644
index 0000000..1134640
--- /dev/null
+++ b/include/gensvm_zv.h
@@ -0,0 +1,32 @@
+/**
+ * @file gensvm_zv.h
+ * @author Gertjan van den Burg
+ * @date 2016-10-17
+ * @brief Header file for gensvm_zv.c
+
+ * Copyright (C)
+
+ This program is free software; you can redistribute it and/or
+ modify it under the terms of the GNU General Public License
+ as published by the Free Software Foundation; either version 2
+ of the License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+
+ */
+
+#include "gensvm_base.h"
+
+void gensvm_calculate_ZV(struct GenModel *model, struct GenData *data,
+ double *ZV);
+void gensvm_calculate_ZV_sparse(struct GenModel *model,
+ struct GenData *data, double *ZV);
+void gensvm_calculate_ZV_dense(struct GenModel *model,
+ struct GenData *data, double *ZV);
diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c
index e83fa47..300df40 100644
--- a/src/gensvm_optimize.c
+++ b/src/gensvm_optimize.c
@@ -257,10 +257,7 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data,
long n = model->n;
long K = model->K;
- if (data->spZ == NULL)
- gensvm_calculate_ZV_dense(model, data, ZV);
- else
- gensvm_calculate_ZV_sparse(model, data, ZV);
+ gensvm_calculate_ZV(model, data, ZV);
for (i=0; i<n; i++) {
for (j=0; j<K; j++) {
@@ -273,41 +270,3 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data,
}
}
-void gensvm_calculate_ZV_sparse(struct GenModel *model,
- struct GenData *data, double *ZV)
-{
- long i, j, jj, jj_start, jj_end, K,
- n_row = data->spZ->n_row;
- double z_ij;
-
- K = model->K;
-
- int *Zia = data->spZ->ia;
- int *Zja = data->spZ->ja;
- double *vals = data->spZ->values;
-
- for (i=0; i<n_row; i++) {
- jj_start = Zia[i];
- jj_end = Zia[i+1];
-
- for (jj=jj_start; jj<jj_end; jj++) {
- j = Zja[jj];
- z_ij = vals[jj];
-
- cblas_daxpy(K-1, z_ij, &model->V[j*(K-1)], 1,
- &ZV[i*(K-1)], 1);
- }
- }
-}
-
-void gensvm_calculate_ZV_dense(struct GenModel *model,
- struct GenData *data, double *ZV)
-{
- long n = model->n;
- long m = model->m;
- long K = model->K;
-
- cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, K-1, m+1,
- 1.0, data->Z, m+1, model->V, K-1, 0, ZV, K-1);
-}
-
diff --git a/src/gensvm_pred.c b/src/gensvm_pred.c
index afa1ab9..31b591c 100644
--- a/src/gensvm_pred.c
+++ b/src/gensvm_pred.c
@@ -30,13 +30,12 @@
void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
long *predy)
{
- long i, j, k, n, m, K, label;
+ long i, j, k, n, K, label;
double norm, min_dist,
*S = NULL,
*ZV = NULL;
n = testdata->n;
- m = testdata->r;
K = model->K;
// allocate necessary memory
@@ -47,21 +46,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
gensvm_simplex(model);
// Generate the simplex space vectors
- cblas_dgemm(
- CblasRowMajor,
- CblasNoTrans,
- CblasNoTrans,
- n,
- K-1,
- m+1,
- 1.0,
- testdata->Z,
- m+1,
- model->V,
- K-1,
- 0.0,
- ZV,
- K-1);
+ gensvm_calculate_ZV(model, testdata, ZV);
// Calculate the distance to each of the vertices of the simplex.
// The closest vertex defines the class label
diff --git a/src/gensvm_zv.c b/src/gensvm_zv.c
new file mode 100644
index 0000000..81a5354
--- /dev/null
+++ b/src/gensvm_zv.c
@@ -0,0 +1,115 @@
+/**
+ * @file gensvm_zv.c
+ * @author Gertjan van den Burg
+ * @date 2016-10-17
+ * @brief Functions for computing the ZV matrix product
+ *
+ * @details
+ * This file exists because the product Z*V of two matrices occurs both in the
+ * computation of the loss function and for predicting class labels. Moreover,
+ * a distinction has to be made between dense Z matrices and sparse Z
+ * matrices, hence a seperate file is warranted.
+
+ * Copyright (C)
+
+ This program is free software; you can redistribute it and/or
+ modify it under the terms of the GNU General Public License
+ as published by the Free Software Foundation; either version 2
+ of the License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+
+ */
+
+#include "gensvm_zv.h"
+
+/**
+ * @brief Wrapper around sparse/dense versions of this function
+ *
+ * @details
+ * This function tests if the data is stored in dense format or sparse format
+ * by testing if GenData::Z is NULL or not, and calls the corresponding
+ * version of this function accordingly.
+ *
+ * @sa
+ * gensvm_calculate_ZV_dense(), gensvm_calculate_ZV_sparse()
+ *
+ * @param[in] model a GenModel instance holding the model
+ * @param[in] data a GenData instance with the data
+ * @param[out] ZV a pre-allocated matrix of appropriate dimensions
+ */
+void gensvm_calculate_ZV(struct GenModel *model, struct GenData *data,
+ double *ZV)
+{
+ if (data->Z == NULL)
+ gensvm_calculate_ZV_sparse(model, data, ZV);
+ else
+ gensvm_calculate_ZV_dense(model, data, ZV);
+}
+
+/**
+ * @brief Compute the product Z*V for when Z is a sparse matrix
+ *
+ * @details
+ * This is a simple sparse-dense matrix multiplication, which uses
+ * cblas_daxpy() for each nonzero element of Z, to compute Z*V.
+ *
+ * @param[in] model a GenModel instance holding the model
+ * @param[in] data a GenData instance with the data
+ * @param[out] ZV a pre-allocated matrix of appropriate dimensions
+ */
+void gensvm_calculate_ZV_sparse(struct GenModel *model,
+ struct GenData *data, double *ZV)
+{
+ long i, j, jj, jj_start, jj_end, K,
+ n_row = data->spZ->n_row;
+ double z_ij;
+
+ K = model->K;
+
+ int *Zia = data->spZ->ia;
+ int *Zja = data->spZ->ja;
+ double *vals = data->spZ->values;
+
+ for (i=0; i<n_row; i++) {
+ jj_start = Zia[i];
+ jj_end = Zia[i+1];
+
+ for (jj=jj_start; jj<jj_end; jj++) {
+ j = Zja[jj];
+ z_ij = vals[jj];
+
+ cblas_daxpy(K-1, z_ij, &model->V[j*(K-1)], 1,
+ &ZV[i*(K-1)], 1);
+ }
+ }
+}
+
+/**
+ * @brief Compute the product Z*V for when Z is a dense matrix
+ *
+ * @details
+ * This function uses cblas_dgemm() to compute the matrix product between Z
+ * and V.
+ *
+ * @param[in] model a GenModel instance holding the model
+ * @param[in] data a GenData instance with the data
+ * @param[out] ZV a pre-allocated matrix of appropriate dimensions
+ */
+void gensvm_calculate_ZV_dense(struct GenModel *model,
+ struct GenData *data, double *ZV)
+{
+ long n = model->n;
+ long m = model->m;
+ long K = model->K;
+
+ cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, K-1, m+1,
+ 1.0, data->Z, m+1, model->V, K-1, 0, ZV, K-1);
+}
diff --git a/tests/src/test_gensvm_pred.c b/tests/src/test_gensvm_pred.c
index 13f0e5a..155f0bf 100644
--- a/tests/src/test_gensvm_pred.c
+++ b/tests/src/test_gensvm_pred.c
@@ -24,7 +24,7 @@
* V = [zeros(1, K-1); R];
*
*/
-char *test_gensvm_predict_labels()
+char *test_gensvm_predict_labels_dense()
{
int n = 12;
int m = 2;
@@ -116,6 +116,104 @@ char *test_gensvm_predict_labels()
return NULL;
}
+char *test_gensvm_predict_labels_sparse()
+{
+ int n = 12;
+ int m = 2;
+ int K = 3;
+
+ struct GenData *data = gensvm_init_data();
+ struct GenModel *model = gensvm_init_model();
+
+ model->n = n;
+ model->m = m;
+ model->K = K;
+
+ data->n = n;
+ data->m = m;
+ data->r = m;
+ data->K = K;
+
+ data->Z = Calloc(double, n*(m+1));
+ data->y = Calloc(long, n);
+
+ matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065);
+ matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935);
+ matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129);
+ matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128);
+ matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937);
+ matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063);
+ matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935);
+ matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064);
+ matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129);
+ matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129);
+ matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064);
+ matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937);
+ matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065);
+ matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935);
+ matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130);
+ matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128);
+ matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939);
+ matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063);
+ matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934);
+ matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064);
+ matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126);
+ matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132);
+ matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000);
+ matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064);
+ matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939);
+
+ // convert Z to a sparse matrix to test the sparse functions
+ data->spZ = gensvm_dense_to_sparse(data->Z, data->n, data->m+1);
+ free(data->Z);
+ data->RAW = NULL;
+ data->Z = NULL;
+
+ gensvm_allocate_model(model);
+
+ matrix_set(model->V, K-1, 0, 0, 0.0000000000000000);
+ matrix_set(model->V, K-1, 0, 1, 0.0000000000000000);
+ matrix_set(model->V, K-1, 1, 0, -2.4494897427831779);
+ matrix_set(model->V, K-1, 1, 1, -0.0000000000000002);
+ matrix_set(model->V, K-1, 2, 0, 0.0000000000000000);
+ matrix_set(model->V, K-1, 2, 1, -2.4494897427831783);
+
+ // start test code
+ long *predy = Calloc(long, n);
+ gensvm_predict_labels(data, model, predy);
+
+ mu_assert(predy[0] == 2, "Incorrect label at index 0");
+ mu_assert(predy[1] == 3, "Incorrect label at index 1");
+ mu_assert(predy[2] == 3, "Incorrect label at index 2");
+ mu_assert(predy[3] == 3, "Incorrect label at index 3");
+ mu_assert(predy[4] == 3, "Incorrect label at index 4");
+ mu_assert(predy[5] == 1, "Incorrect label at index 5");
+ mu_assert(predy[6] == 1, "Incorrect label at index 6");
+ mu_assert(predy[7] == 1, "Incorrect label at index 7");
+ mu_assert(predy[8] == 1, "Incorrect label at index 8");
+ mu_assert(predy[9] == 2, "Incorrect label at index 9");
+ mu_assert(predy[10] == 2, "Incorrect label at index 10");
+ mu_assert(predy[11] == 2, "Incorrect label at index 11");
+
+ // end test code
+ gensvm_free_data(data);
+ gensvm_free_model(model);
+ free(predy);
+
+ return NULL;
+}
+
char *test_gensvm_prediction_perf()
{
int i, n = 8;
@@ -156,7 +254,8 @@ char *test_gensvm_prediction_perf()
char *all_tests()
{
mu_suite_start();
- mu_run_test(test_gensvm_predict_labels);
+ mu_run_test(test_gensvm_predict_labels_dense);
+ mu_run_test(test_gensvm_predict_labels_sparse);
mu_run_test(test_gensvm_prediction_perf);
return NULL;
diff --git a/tests/src/test_gensvm_zv.c b/tests/src/test_gensvm_zv.c
new file mode 100644
index 0000000..21dd897
--- /dev/null
+++ b/tests/src/test_gensvm_zv.c
@@ -0,0 +1,251 @@
+/**
+ * @file test_gensvm_zv.c
+ * @author Gertjan van den Burg
+ * @date 2016-10-17
+ * @brief Unit tests for gensvm_zv.c functions
+
+ * Copyright (C)
+
+ This program is free software; you can redistribute it and/or
+ modify it under the terms of the GNU General Public License
+ as published by the Free Software Foundation; either version 2
+ of the License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+
+ */
+
+#include "minunit.h"
+#include "gensvm_zv.h"
+
+char *test_zv_dense()
+{
+ int n = 8,
+ m = 3,
+ K = 3;
+
+ struct GenModel *model = gensvm_init_model();
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ model->V = Calloc(double, (m+1)*(K-1));
+ matrix_set(model->V, model->K-1, 0, 0, 0.9025324416711976);
+ matrix_set(model->V, model->K-1, 0, 1, 0.9776784486541952);
+ matrix_set(model->V, model->K-1, 1, 0, 0.8336347240271171);
+ matrix_set(model->V, model->K-1, 1, 1, 0.1213543508830703);
+ matrix_set(model->V, model->K-1, 2, 0, 0.9401310852208050);
+ matrix_set(model->V, model->K-1, 2, 1, 0.7407478086613410);
+ matrix_set(model->V, model->K-1, 3, 0, 0.9053353815353901);
+ matrix_set(model->V, model->K-1, 3, 1, 0.8056059951641629);
+
+ struct GenData *data = gensvm_init_data();
+ data->n = n;
+ data->m = m;
+ data->K = K;
+ data->Z = Calloc(double, n*(m+1));
+ matrix_set(data->Z, data->m+1, 0, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 0, 1, 0.4787662921736276);
+ matrix_set(data->Z, data->m+1, 0, 2, 0.7983044792882817);
+ matrix_set(data->Z, data->m+1, 0, 3, 0.4273006962165122);
+ matrix_set(data->Z, data->m+1, 1, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 1, 1, 0.7160319769123790);
+ matrix_set(data->Z, data->m+1, 1, 2, 0.5233066338418962);
+ matrix_set(data->Z, data->m+1, 1, 3, 0.4063256860579537);
+ matrix_set(data->Z, data->m+1, 2, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 2, 1, 0.3735389652435536);
+ matrix_set(data->Z, data->m+1, 2, 2, 0.8156214578257802);
+ matrix_set(data->Z, data->m+1, 2, 3, 0.6928367712901857);
+ matrix_set(data->Z, data->m+1, 3, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 3, 1, 0.3694690105850765);
+ matrix_set(data->Z, data->m+1, 3, 2, 0.8539671806454873);
+ matrix_set(data->Z, data->m+1, 3, 3, 0.5455108033084728);
+ matrix_set(data->Z, data->m+1, 4, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 4, 1, 0.8802158533820680);
+ matrix_set(data->Z, data->m+1, 4, 2, 0.0690778177684403);
+ matrix_set(data->Z, data->m+1, 4, 3, 0.4513353324958240);
+ matrix_set(data->Z, data->m+1, 5, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 5, 1, 0.7752402729955837);
+ matrix_set(data->Z, data->m+1, 5, 2, 0.3941285577056867);
+ matrix_set(data->Z, data->m+1, 5, 3, 0.2921042477960945);
+ matrix_set(data->Z, data->m+1, 6, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 6, 1, 0.6139038657913901);
+ matrix_set(data->Z, data->m+1, 6, 2, 0.4529743309354828);
+ matrix_set(data->Z, data->m+1, 6, 3, 0.7295983135133345);
+ matrix_set(data->Z, data->m+1, 7, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 7, 1, 0.7663625136928905);
+ matrix_set(data->Z, data->m+1, 7, 2, 0.3845759571625976);
+ matrix_set(data->Z, data->m+1, 7, 3, 0.2291505633226144);
+
+ // start test code //
+ double *ZV = Calloc(double, n*(K-1));
+ double eps = 1e-14;
+ gensvm_calculate_ZV(model, data, ZV);
+
+ mu_assert(fabs(matrix_get(ZV, K-1, 0, 0) - 2.4390099428102818) < eps,
+ "Incorrect ZV at 0, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 0, 1) - 1.9713571175527906) < eps,
+ "Incorrect ZV at 0, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 1, 0) - 2.3592794147310747) < eps,
+ "Incorrect ZV at 1, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 1, 1) - 1.7795486953777246) < eps,
+ "Incorrect ZV at 1, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 2, 0) - 2.6079682228282564) < eps,
+ "Incorrect ZV at 2, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 2, 1) - 2.1853322915140310) < eps,
+ "Incorrect ZV at 2, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 3, 0) - 2.5072459618750060) < eps,
+ "Incorrect ZV at 3, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 3, 1) - 2.0945562119091297) < eps,
+ "Incorrect ZV at 3, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 4, 0) - 2.1098629909184887) < eps,
+ "Incorrect ZV at 4, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 4, 1) - 1.4992641640054902) < eps,
+ "Incorrect ZV at 4, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 5, 0) - 2.1837844720035213) < eps,
+ "Incorrect ZV at 5, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 5, 1) - 1.5990280274507829) < eps,
+ "Incorrect ZV at 5, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 6, 0) - 2.5006904382610986) < eps,
+ "Incorrect ZV at 6, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 6, 1) - 1.9754868722402175) < eps,
+ "Incorrect ZV at 6, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 7, 0) - 2.1104087689101294) < eps,
+ "Incorrect ZV at 7, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 7, 1) - 1.5401587391844891) < eps,
+ "Incorrect ZV at 7, 1");
+
+ free(ZV);
+ // end test code //
+ gensvm_free_data(data);
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *test_zv_sparse()
+{
+ int n = 8,
+ m = 3,
+ K = 3;
+
+ struct GenModel *model = gensvm_init_model();
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ model->V = Calloc(double, (m+1)*(K-1));
+ matrix_set(model->V, model->K-1, 0, 0, 0.9025324416711976);
+ matrix_set(model->V, model->K-1, 0, 1, 0.9776784486541952);
+ matrix_set(model->V, model->K-1, 1, 0, 0.8336347240271171);
+ matrix_set(model->V, model->K-1, 1, 1, 0.1213543508830703);
+ matrix_set(model->V, model->K-1, 2, 0, 0.9401310852208050);
+ matrix_set(model->V, model->K-1, 2, 1, 0.7407478086613410);
+ matrix_set(model->V, model->K-1, 3, 0, 0.9053353815353901);
+ matrix_set(model->V, model->K-1, 3, 1, 0.8056059951641629);
+
+ struct GenData *data = gensvm_init_data();
+ data->n = n;
+ data->m = m;
+ data->K = K;
+ data->Z = Calloc(double, n*(m+1));
+ matrix_set(data->Z, data->m+1, 0, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 0, 1, 0.4787662921736276);
+ matrix_set(data->Z, data->m+1, 0, 2, 0.7983044792882817);
+ matrix_set(data->Z, data->m+1, 0, 3, 0.4273006962165122);
+ matrix_set(data->Z, data->m+1, 1, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 1, 1, 0.7160319769123790);
+ matrix_set(data->Z, data->m+1, 1, 2, 0.5233066338418962);
+ matrix_set(data->Z, data->m+1, 1, 3, 0.4063256860579537);
+ matrix_set(data->Z, data->m+1, 2, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 2, 1, 0.3735389652435536);
+ matrix_set(data->Z, data->m+1, 2, 2, 0.8156214578257802);
+ matrix_set(data->Z, data->m+1, 2, 3, 0.6928367712901857);
+ matrix_set(data->Z, data->m+1, 3, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 3, 1, 0.3694690105850765);
+ matrix_set(data->Z, data->m+1, 3, 2, 0.8539671806454873);
+ matrix_set(data->Z, data->m+1, 3, 3, 0.5455108033084728);
+ matrix_set(data->Z, data->m+1, 4, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 4, 1, 0.8802158533820680);
+ matrix_set(data->Z, data->m+1, 4, 2, 0.0690778177684403);
+ matrix_set(data->Z, data->m+1, 4, 3, 0.4513353324958240);
+ matrix_set(data->Z, data->m+1, 5, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 5, 1, 0.7752402729955837);
+ matrix_set(data->Z, data->m+1, 5, 2, 0.3941285577056867);
+ matrix_set(data->Z, data->m+1, 5, 3, 0.2921042477960945);
+ matrix_set(data->Z, data->m+1, 6, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 6, 1, 0.6139038657913901);
+ matrix_set(data->Z, data->m+1, 6, 2, 0.4529743309354828);
+ matrix_set(data->Z, data->m+1, 6, 3, 0.7295983135133345);
+ matrix_set(data->Z, data->m+1, 7, 0, 1.0000000000000000);
+ matrix_set(data->Z, data->m+1, 7, 1, 0.7663625136928905);
+ matrix_set(data->Z, data->m+1, 7, 2, 0.3845759571625976);
+ matrix_set(data->Z, data->m+1, 7, 3, 0.2291505633226144);
+
+ // convert Z to sparse matrix
+ data->spZ = gensvm_dense_to_sparse(data->Z, data->n, data->m+1);
+ free(data->Z);
+ data->RAW = NULL;
+ data->Z = NULL;
+
+ // start test code //
+ double *ZV = Calloc(double, n*(K-1));
+ double eps = 1e-14;
+ gensvm_calculate_ZV(model, data, ZV);
+
+ mu_assert(fabs(matrix_get(ZV, K-1, 0, 0) - 2.4390099428102818) < eps,
+ "Incorrect ZV at 0, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 0, 1) - 1.9713571175527906) < eps,
+ "Incorrect ZV at 0, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 1, 0) - 2.3592794147310747) < eps,
+ "Incorrect ZV at 1, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 1, 1) - 1.7795486953777246) < eps,
+ "Incorrect ZV at 1, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 2, 0) - 2.6079682228282564) < eps,
+ "Incorrect ZV at 2, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 2, 1) - 2.1853322915140310) < eps,
+ "Incorrect ZV at 2, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 3, 0) - 2.5072459618750060) < eps,
+ "Incorrect ZV at 3, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 3, 1) - 2.0945562119091297) < eps,
+ "Incorrect ZV at 3, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 4, 0) - 2.1098629909184887) < eps,
+ "Incorrect ZV at 4, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 4, 1) - 1.4992641640054902) < eps,
+ "Incorrect ZV at 4, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 5, 0) - 2.1837844720035213) < eps,
+ "Incorrect ZV at 5, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 5, 1) - 1.5990280274507829) < eps,
+ "Incorrect ZV at 5, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 6, 0) - 2.5006904382610986) < eps,
+ "Incorrect ZV at 6, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 6, 1) - 1.9754868722402175) < eps,
+ "Incorrect ZV at 6, 1");
+ mu_assert(fabs(matrix_get(ZV, K-1, 7, 0) - 2.1104087689101294) < eps,
+ "Incorrect ZV at 7, 0");
+ mu_assert(fabs(matrix_get(ZV, K-1, 7, 1) - 1.5401587391844891) < eps,
+ "Incorrect ZV at 7, 1");
+
+ free(ZV);
+ // end test code //
+ gensvm_free_data(data);
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *all_tests()
+{
+ mu_suite_start();
+ mu_run_test(test_zv_dense);
+ mu_run_test(test_zv_sparse);
+
+ return NULL;
+}
+
+RUN_TESTS(all_tests);