aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/gensvm_base.h2
-rw-r--r--include/gensvm_simplex.h5
-rw-r--r--src/gensvm_base.c5
-rw-r--r--src/gensvm_optimize.c58
-rw-r--r--src/gensvm_pred.c8
-rw-r--r--src/gensvm_simplex.c47
-rw-r--r--tests/src/test_gensvm_optimize.c210
-rw-r--r--tests/src/test_gensvm_simplex.c182
8 files changed, 236 insertions, 281 deletions
diff --git a/include/gensvm_base.h b/include/gensvm_base.h
index 390a9b6..efeaa4d 100644
--- a/include/gensvm_base.h
+++ b/include/gensvm_base.h
@@ -82,7 +82,7 @@ struct GenModel {
double *U;
///< simplex matrix
double *UU;
- ///< 3D simplex difference matrix
+ ///< simplex difference matrix
double *Q;
///< error matrix
double *H;
diff --git a/include/gensvm_simplex.h b/include/gensvm_simplex.h
index d774afa..4f8a475 100644
--- a/include/gensvm_simplex.h
+++ b/include/gensvm_simplex.h
@@ -10,9 +10,10 @@
#define GENSVM_SIMPLEX_H
// includes
-#include "gensvm_globals.h"
+#include "gensvm_base.h"
// forward declarations
-void gensvm_simplex(long K, double *U);
+void gensvm_simplex(struct GenModel *model);
+void gensvm_simplex_diff(struct GenModel *model);
#endif
diff --git a/src/gensvm_base.c b/src/gensvm_base.c
index 2c30530..c77e704 100644
--- a/src/gensvm_base.c
+++ b/src/gensvm_base.c
@@ -119,7 +119,7 @@ void gensvm_allocate_model(struct GenModel *model)
model->V = Calloc(double, (m+1)*(K-1));
model->Vbar = Calloc(double, (m+1)*(K-1));
model->U = Calloc(double, K*(K-1));
- model->UU = Calloc(double, n*K*(K-1));
+ model->UU = Calloc(double, K*K*(K-1));
model->Q = Calloc(double, n*K);
model->H = Calloc(double, n*K);
model->rho = Calloc(double, n);
@@ -145,9 +145,6 @@ void gensvm_reallocate_model(struct GenModel *model, long n, long m)
if (model->n == n && model->m == m)
return;
if (model->n != n) {
- model->UU = Realloc(model->UU, double, n*K*(K-1));
- Memset(model->UU, double, n*K*(K-1));
-
model->Q = Realloc(model->Q, double, n*K);
Memset(model->Q, double, n*K);
diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c
index d08f917..a16512f 100644
--- a/src/gensvm_optimize.c
+++ b/src/gensvm_optimize.c
@@ -66,9 +66,8 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data)
note("\tepsilon = %g\n", model->epsilon);
note("\n");
- gensvm_simplex(model->K, model->U);
- gensvm_simplex_diff(model, data);
gensvm_simplex(model);
+ gensvm_simplex_diff(model);
L = gensvm_get_loss(model, data, ZV);
Lbar = L + 2.0*model->epsilon*L;
@@ -576,39 +575,6 @@ void gensvm_get_update(struct GenModel *model, struct GenData *data,
}
}
-/**
- * @brief Generate the simplex difference matrix
- *
- * @details
- * The simplex difference matrix is a 3D matrix which is constructed
- * as follows. For each instance i, the difference vectors between the row of
- * the simplex matrix corresponding to the class label of instance i and the
- * other rows of the simplex matrix are calculated. These difference vectors
- * are stored in a matrix, which is one horizontal slice of the 3D matrix.
- *
- * We use the indices i, j, k for the three dimensions n, K-1, K of UU. Then
- * the i,j,k -th element of UU is equal to U(y[i]-1, j) - U(k, j).
- *
- * @param[in,out] model the corresponding GenModel
- * @param[in] data the corresponding GenData
- *
- */
-void gensvm_simplex_diff(struct GenModel *model, struct GenData *data)
-{
- long i, j, k;
- double value;
-
- long n = model->n;
- long K = model->K;
-
- for (i=0; i<n; i++) {
- for (j=0; j<K-1; j++) {
- for (k=0; k<K; k++) {
- value = matrix_get(model->U, K-1,
- data->y[i]-1, j);
- value -= matrix_get(model->U, K-1, k, j);
- matrix3_set(model->UU, K-1, K, i, j, k, value);
- }
}
}
}
@@ -689,21 +655,17 @@ void gensvm_calculate_huber(struct GenModel *model)
* allocated. In addition, the matrix ZV is calculated here. It is assigned
* to a pre-allocated block of memory, which is passed to this function.
*
- * @todo
- * Transform UU to small UU then fix that here
- *
* @param[in,out] model the corresponding GenModel
* @param[in] data the corresponding GenData
* @param[in,out] ZV a pointer to a memory block for ZV. On exit
* this block is updated with the new ZV matrix
* calculated with GenModel::V
- *
*/
void gensvm_calculate_errors(struct GenModel *model, struct GenData *data,
double *ZV)
{
- long i, j, k;
- double zv, value;
+ long i, j;
+ double q, *uu_row;
long n = model->n;
long m = model->m;
@@ -725,15 +687,13 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data,
ZV,
K-1);
- Memset(model->Q, double, n*K);
for (i=0; i<n; i++) {
- for (j=0; j<K-1; j++) {
- zv = matrix_get(ZV, K-1, i, j);
- for (k=0; k<K; k++) {
- value = zv * matrix3_get(model->UU, K-1, K, i,
- j, k);
- matrix_add(model->Q, K, i, k, value);
- }
+ for (j=0; j<K; j++) {
+ if (j == (data->y[i]-1))
+ continue;
+ uu_row = &model->UU[((data->y[i]-1)*K+j)*(K-1)];
+ q = cblas_ddot(K-1, &ZV[i*(K-1)], 1, uu_row, 1);
+ matrix_set(model->Q, K, i, j, q);
}
}
}
diff --git a/src/gensvm_pred.c b/src/gensvm_pred.c
index 8a9a43e..43b27cc 100644
--- a/src/gensvm_pred.c
+++ b/src/gensvm_pred.c
@@ -31,7 +31,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
long *predy)
{
long i, j, k, n, m, K, label;
- double norm, min_dist, *S, *ZV, *U;
+ double norm, min_dist, *S, *ZV;
n = testdata->n;
m = testdata->r;
@@ -40,10 +40,9 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
// allocate necessary memory
S = Calloc(double, K-1);
ZV = Calloc(double, n*(K-1));
- U = Calloc(double, K*(K-1));
// Generate the simplex matrix
- gensvm_simplex(K, U);
+ gensvm_simplex(model);
// Generate the simplex space vectors
cblas_dgemm(
@@ -70,7 +69,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
for (j=0; j<K; j++) {
for (k=0; k<K-1; k++) {
S[k] = matrix_get(ZV, K-1, i, k) -
- matrix_get(U, K-1, j, k);
+ matrix_get(model->U, K-1, j, k);
}
norm = cblas_dnrm2(K-1, S, 1);
if (norm < min_dist) {
@@ -82,7 +81,6 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
}
free(ZV);
- free(U);
free(S);
}
diff --git a/src/gensvm_simplex.c b/src/gensvm_simplex.c
index 1fd5f14..a704d85 100644
--- a/src/gensvm_simplex.c
+++ b/src/gensvm_simplex.c
@@ -25,21 +25,58 @@
* @param[in] K number of classes
* @param[in,out] U simplex matrix of size K * (K-1)
*/
-void gensvm_simplex(long K, double *U)
+void gensvm_simplex(struct GenModel *model)
{
- long i, j;
+ long i, j, K = model->K;
+
for (i=0; i<K; i++) {
for (j=0; j<K-1; j++) {
if (i <= j) {
- matrix_set(U, K-1, i, j,
+ matrix_set(model->U, K-1, i, j,
-1.0/sqrt(2.0*(j+1)*(j+2)));
} else if (i == j+1) {
- matrix_set(U, K-1, i, j,
+ matrix_set(model->U, K-1, i, j,
sqrt((j+1)/(2.0*(j+2))));
} else {
- matrix_set(U, K-1, i, j, 0.0);
+ matrix_set(model->U, K-1, i, j, 0.0);
}
}
}
}
+/**
+ * @brief Generate the simplex difference matrix
+ *
+ * @details
+ * The simplex difference matrix is a 2D block matrix which is constructed
+ * as follows. For each class i, we have a block of K rows and K-1 columns.
+ * Each row in the block for class i contains a row vector with the difference
+ * of the simplex matrix, U(i, :) - U(j, :).
+ *
+ * In the paper the notation @f$\boldsymbol{\delta}_{kj}'@f$ is used for the
+ * difference vector of @f$\textbf{u}_k' - \textbf{u}_j'@f$, where
+ * @f$\textbf{u}_k'@f$ corresponds to row k of @f$\textbf{U}@f$. Due to the
+ * indexing in the paper being 1-based and C indexing is 0 based, the vector
+ * @f$\boldsymbol{\delta}_{kj}'@f$ corresponds to the row (k-1)*K+(j-1) in the
+ * UU matrix generated here.
+ *
+ * @param[in,out] model the corresponding GenModel
+ *
+ */
+void gensvm_simplex_diff(struct GenModel *model)
+{
+ long i, j, l, K = model->K;
+ double value;
+
+ // UU is a 2D block matrix, where block i has the differences:
+ // U(i, :) - U(j, :) for all j
+ for (i=0; i<K; i++) {
+ for (j=0; j<K; j++) {
+ for (l=0; l<K-1; l++) {
+ value = matrix_get(model->U, K-1, i, l);
+ value -= matrix_get(model->U, K-1, j, l);
+ matrix_set(model->UU, K-1, i*K+j, l, value);
+ }
+ }
+ }
+}
diff --git a/tests/src/test_gensvm_optimize.c b/tests/src/test_gensvm_optimize.c
index f0947a4..6a6571d 100644
--- a/tests/src/test_gensvm_optimize.c
+++ b/tests/src/test_gensvm_optimize.c
@@ -196,8 +196,8 @@ char *test_gensvm_get_loss_1()
gensvm_allocate_model(model);
gensvm_initialize_weights(data, model);
- gensvm_simplex(model->K, model->U);
- gensvm_simplex_diff(model, data);
+ gensvm_simplex(model);
+ gensvm_simplex_diff(model);
matrix_set(model->V, model->K-1, 0, 0, 0.6019309459245683);
matrix_set(model->V, model->K-1, 0, 1, 0.0063825200426701);
@@ -294,8 +294,8 @@ char *test_gensvm_get_loss_2()
gensvm_allocate_model(model);
gensvm_initialize_weights(data, model);
- gensvm_simplex(model->K, model->U);
- gensvm_simplex_diff(model, data);
+ gensvm_simplex(model);
+ gensvm_simplex_diff(model);
matrix_set(model->V, model->K-1, 0, 0, 0.6019309459245683);
matrix_set(model->V, model->K-1, 0, 1, 0.0063825200426701);
@@ -365,19 +365,19 @@ char *test_gensvm_calculate_omega()
matrix_set(model->H, model->K, 4, 2, 0.8184193969741095);
// start test code //
- mu_assert(fabs(gensvm_calculate_omega(model, 0) -
+ mu_assert(fabs(gensvm_calculate_omega(model, data, 0) -
0.7394076262220608) < 1e-14,
"Incorrect omega at 0");
- mu_assert(fabs(gensvm_calculate_omega(model, 1) -
+ mu_assert(fabs(gensvm_calculate_omega(model, data, 1) -
0.7294526264247443) < 1e-14,
"Incorrect omega at 1");
- mu_assert(fabs(gensvm_calculate_omega(model, 2) -
+ mu_assert(fabs(gensvm_calculate_omega(model, data, 2) -
0.6802499471888741) < 1e-14,
"Incorrect omega at 2");
- mu_assert(fabs(gensvm_calculate_omega(model, 3) -
+ mu_assert(fabs(gensvm_calculate_omega(model, data, 3) -
0.6886792032441273) < 1e-14,
"Incorrect omega at 3");
- mu_assert(fabs(gensvm_calculate_omega(model, 4) -
+ mu_assert(fabs(gensvm_calculate_omega(model, data, 4) -
0.8695329737474283) < 1e-14,
"Incorrect omega at 4");
@@ -593,8 +593,8 @@ char *test_gensvm_update_B()
model->kappa = 0.5;
gensvm_allocate_model(model);
- gensvm_simplex(model->K, model->U);
- gensvm_simplex_diff(model, data);
+ gensvm_simplex(model);
+ gensvm_simplex_diff(model);
gensvm_initialize_weights(data, model);
// start test code //
@@ -639,6 +639,7 @@ char *test_gensvm_get_Avalue_update_B()
mu_test_missing();
return NULL;
}
+*/
char *test_gensvm_get_update()
{
@@ -709,8 +710,8 @@ char *test_gensvm_get_update()
// initialize matrices
gensvm_allocate_model(model);
gensvm_initialize_weights(data, model);
- gensvm_simplex(model->K, model->U);
- gensvm_simplex_diff(model, data);
+ gensvm_simplex(model);
+ gensvm_simplex_diff(model);
// initialize V
matrix_set(model->V, model->K-1, 0, 0, -0.7593642121025029);
@@ -729,7 +730,7 @@ char *test_gensvm_get_update()
double *ZAZVT = Calloc(double, (m+1)*(K-1));
// these need to be prepared for the update call
- gensvm_calculate_errors(model, data);
+ gensvm_calculate_errors(model, data, ZV);
gensvm_calculate_huber(model);
// run the actual update call
@@ -774,182 +775,8 @@ char *test_gensvm_get_update()
return NULL;
}
-char *test_gensvm_simplex_diff()
-{
- struct GenData *data = gensvm_init_data();
- struct GenModel *model = gensvm_init_model();
-
- int n = 8,
- m = 3,
- K = 3;
- model->n = n;
- model->m = m;
- model->K = K;
- data->n = n;
- data->m = m;
- data->K = K;
-
- data->y = Calloc(long, n);
-
- gensvm_allocate_model(model);
- gensvm_simplex(model->K, model->U);
-
- data->y[0] = 2;
- data->y[1] = 1;
- data->y[2] = 3;
- data->y[3] = 2;
- data->y[4] = 3;
- data->y[5] = 3;
- data->y[6] = 1;
- data->y[7] = 2;
-
- // start test code //
- gensvm_simplex_diff(model, data);
+ free(ZV);
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 0) -
- 1.0000000000000000) < 1e-14,
- "Incorrect value at UU(0, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(0, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 2) -
- 0.5000000000000000) < 1e-14,
- "Incorrect value at UU(0, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(0, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(0, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 2) -
- -0.8660254037844388) < 1e-14,
- "Incorrect value at UU(0, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(1, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 1) -
- -1.0000000000000000) < 1e-14,
- "Incorrect value at UU(1, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 2) -
- -0.5000000000000000) < 1e-14,
- "Incorrect value at UU(1, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(1, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(1, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 2) -
- -0.8660254037844388) < 1e-14,
- "Incorrect value at UU(1, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 0) -
- 0.5000000000000000) < 1e-14,
- "Incorrect value at UU(2, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 1) -
- -0.5000000000000000) < 1e-14,
- "Incorrect value at UU(2, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 2) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(2, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 0) -
- 0.8660254037844388) < 1e-14,
- "Incorrect value at UU(2, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 1) -
- 0.8660254037844388) < 1e-14,
- "Incorrect value at UU(2, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 2) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(2, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 0) -
- 1.0000000000000000) < 1e-14,
- "Incorrect value at UU(3, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(3, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 2) -
- 0.5000000000000000) < 1e-14,
- "Incorrect value at UU(3, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(3, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(3, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 2) -
- -0.8660254037844388) < 1e-14,
- "Incorrect value at UU(3, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 0) -
- 0.5000000000000000) < 1e-14,
- "Incorrect value at UU(4, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 1) -
- -0.5000000000000000) < 1e-14,
- "Incorrect value at UU(4, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 2) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(4, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 0) -
- 0.8660254037844388) < 1e-14,
- "Incorrect value at UU(4, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 1) -
- 0.8660254037844388) < 1e-14,
- "Incorrect value at UU(4, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 2) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(4, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 0) -
- 0.5000000000000000) < 1e-14,
- "Incorrect value at UU(5, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 1) -
- -0.5000000000000000) < 1e-14,
- "Incorrect value at UU(5, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 2) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(5, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 0) -
- 0.8660254037844388) < 1e-14,
- "Incorrect value at UU(5, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 1) -
- 0.8660254037844388) < 1e-14,
- "Incorrect value at UU(5, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 2) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(5, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(6, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 1) -
- -1.0000000000000000) < 1e-14,
- "Incorrect value at UU(6, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 2) -
- -0.5000000000000000) < 1e-14,
- "Incorrect value at UU(6, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(6, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(6, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 2) -
- -0.8660254037844388) < 1e-14,
- "Incorrect value at UU(6, 1, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 0) -
- 1.0000000000000000) < 1e-14,
- "Incorrect value at UU(7, 0, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(7, 0, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 2) -
- 0.5000000000000000) < 1e-14,
- "Incorrect value at UU(7, 0, 2)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 0) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(7, 1, 0)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 1) -
- 0.0000000000000000) < 1e-14,
- "Incorrect value at UU(7, 1, 1)");
- mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 2) -
- -0.8660254037844388) < 1e-14,
- "Incorrect value at UU(7, 1, 2)");
// end test code //
gensvm_free_model(model);
@@ -1021,8 +848,8 @@ char *test_gensvm_calculate_errors()
model->m = m;
model->K = K;
gensvm_allocate_model(model);
- gensvm_simplex(model->K, model->U);
- gensvm_simplex_diff(model, data);
+ gensvm_simplex(model);
+ gensvm_simplex_diff(model);
matrix_set(model->V, model->K-1, 0, 0, 0.6019309459245683);
matrix_set(model->V, model->K-1, 0, 1, 0.0063825200426701);
@@ -1629,7 +1456,6 @@ char *all_tests()
mu_run_test(test_gensvm_get_Avalue_update_B);
mu_run_test(test_gensvm_get_update);
- mu_run_test(test_gensvm_simplex_diff);
mu_run_test(test_gensvm_calculate_errors);
mu_run_test(test_gensvm_calculate_huber);
mu_run_test(test_gensvm_step_doubling);
diff --git a/tests/src/test_gensvm_simplex.c b/tests/src/test_gensvm_simplex.c
index c13c4ea..c217ca0 100644
--- a/tests/src/test_gensvm_simplex.c
+++ b/tests/src/test_gensvm_simplex.c
@@ -10,48 +10,183 @@
char *test_simplex_1()
{
- double *U = Calloc(double, 2*1);
+ struct GenModel *model = gensvm_init_model();
+ model->U = Calloc(double, 2*1);
+ model->K = 2;
- gensvm_simplex(2, U);
+ gensvm_simplex(model);
- mu_assert(matrix_get(U, 1, 0, 0) == -0.5, "U(0, 0) incorrect.");
- mu_assert(matrix_get(U, 1, 1, 0) == 0.5, "U(1, 0) incorrect.");
+ mu_assert(fabs(matrix_get(model->U, 1, 0, 0) - -0.5) < 1e-14,
+ "U(0, 0) incorrect.");
+ mu_assert(fabs(matrix_get(model->U, 1, 1, 0) - 0.5) < 1e-14,
+ "U(1, 0) incorrect.");
- free(U);
+ gensvm_free_model(model);
return NULL;
}
char *test_simplex_2()
{
- double *U = Calloc(double, 4*3);
-
- gensvm_simplex(4, U);
-
- mu_assert(matrix_get(U, 3, 0, 0) == -0.5, "U(0, 0) incorrect.");
- mu_assert(matrix_get(U, 3, 1, 0) == 0.5, "U(1, 0) incorrect.");
- mu_assert(matrix_get(U, 3, 2, 0) == 0.0, "U(2, 0) incorrect.");
- mu_assert(matrix_get(U, 3, 3, 0) == 0.0, "U(3, 0) incorrect.");
-
- mu_assert(fabs(matrix_get(U, 3, 0, 1) - -0.5/sqrt(3)) < 1e-14,
+ struct GenModel *model = gensvm_init_model();
+ model->U = Calloc(double, 4*3);
+ model->K = 4;
+
+ gensvm_simplex(model);
+
+ mu_assert(fabs(matrix_get(model->U, 3, 0, 0) - -0.5) < 1e-14,
+ "U(0, 0) incorrect.");
+ mu_assert(fabs(matrix_get(model->U, 3, 1, 0) - 0.5) < 1e-14,
+ "U(1, 0) incorrect.");
+ mu_assert(fabs(matrix_get(model->U, 3, 2, 0) - 0.0) < 1e-14,
+ "U(2, 0) incorrect.");
+ mu_assert(fabs(matrix_get(model->U, 3, 3, 0) - 0.0) < 1e-14,
+ "U(3, 0) incorrect.");
+
+ mu_assert(fabs(matrix_get(model->U, 3, 0, 1) - -0.5/sqrt(3)) < 1e-14,
"U(0, 1) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 1, 1) - -0.5/sqrt(3)) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 1, 1) - -0.5/sqrt(3)) < 1e-14,
"U(1, 1) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 2, 1) - 1.0/sqrt(3)) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 2, 1) - 1.0/sqrt(3)) < 1e-14,
"U(2, 1) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 3, 1) - 0.0) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 3, 1) - 0.0) < 1e-14,
"U(3, 1) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 0, 2) - -1.0/sqrt(24)) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 0, 2) - -1.0/sqrt(24)) < 1e-14,
"U(0, 2) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 1, 2) - -1.0/sqrt(24)) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 1, 2) - -1.0/sqrt(24)) < 1e-14,
"U(1, 2) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 2, 2) - -1.0/sqrt(24)) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 2, 2) - -1.0/sqrt(24)) < 1e-14,
"U(2, 2) incorrect.");
- mu_assert(fabs(matrix_get(U, 3, 3, 2) - 3.0/sqrt(24)) < 1e-14,
+ mu_assert(fabs(matrix_get(model->U, 3, 3, 2) - 3.0/sqrt(24)) < 1e-14,
"U(3, 2) incorrect.");
- free(U);
+ gensvm_free_model(model);
+
+ return NULL;
+}
+
+char *test_gensvm_simplex_diff()
+{
+ struct GenData *data = gensvm_init_data();
+ struct GenModel *model = gensvm_init_model();
+
+ int n = 8,
+ m = 3,
+ K = 4;
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ data->n = n;
+ data->m = m;
+ data->K = K;
+
+ gensvm_allocate_model(model);
+ gensvm_simplex(model);
+
+ // start test code //
+ gensvm_simplex_diff(model);
+ mu_assert(fabs(model->UU[0] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 0");
+ mu_assert(fabs(model->UU[1] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 1");
+ mu_assert(fabs(model->UU[2] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 2");
+ mu_assert(fabs(model->UU[3] - -1.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 3");
+ mu_assert(fabs(model->UU[4] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 4");
+ mu_assert(fabs(model->UU[5] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 5");
+ mu_assert(fabs(model->UU[6] - -0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 6");
+ mu_assert(fabs(model->UU[7] - -0.8660254037844388) < 1e-14,
+ "Incorrect value of model->UU at 7");
+ mu_assert(fabs(model->UU[8] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 8");
+ mu_assert(fabs(model->UU[9] - -0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 9");
+ mu_assert(fabs(model->UU[10] - -0.2886751345948129) < 1e-14,
+ "Incorrect value of model->UU at 10");
+ mu_assert(fabs(model->UU[11] - -0.8164965809277261) < 1e-14,
+ "Incorrect value of model->UU at 11");
+ mu_assert(fabs(model->UU[12] - 1.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 12");
+ mu_assert(fabs(model->UU[13] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 13");
+ mu_assert(fabs(model->UU[14] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 14");
+ mu_assert(fabs(model->UU[15] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 15");
+ mu_assert(fabs(model->UU[16] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 16");
+ mu_assert(fabs(model->UU[17] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 17");
+ mu_assert(fabs(model->UU[18] - 0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 18");
+ mu_assert(fabs(model->UU[19] - -0.8660254037844388) < 1e-14,
+ "Incorrect value of model->UU at 19");
+ mu_assert(fabs(model->UU[20] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 20");
+ mu_assert(fabs(model->UU[21] - 0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 21");
+ mu_assert(fabs(model->UU[22] - -0.2886751345948129) < 1e-14,
+ "Incorrect value of model->UU at 22");
+ mu_assert(fabs(model->UU[23] - -0.8164965809277261) < 1e-14,
+ "Incorrect value of model->UU at 23");
+ mu_assert(fabs(model->UU[24] - 0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 24");
+ mu_assert(fabs(model->UU[25] - 0.8660254037844388) < 1e-14,
+ "Incorrect value of model->UU at 25");
+ mu_assert(fabs(model->UU[26] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 26");
+ mu_assert(fabs(model->UU[27] - -0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 27");
+ mu_assert(fabs(model->UU[28] - 0.8660254037844388) < 1e-14,
+ "Incorrect value of model->UU at 28");
+ mu_assert(fabs(model->UU[29] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 29");
+ mu_assert(fabs(model->UU[30] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 30");
+ mu_assert(fabs(model->UU[31] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 31");
+ mu_assert(fabs(model->UU[32] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 32");
+ mu_assert(fabs(model->UU[33] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 33");
+ mu_assert(fabs(model->UU[34] - 0.5773502691896258) < 1e-14,
+ "Incorrect value of model->UU at 34");
+ mu_assert(fabs(model->UU[35] - -0.8164965809277261) < 1e-14,
+ "Incorrect value of model->UU at 35");
+ mu_assert(fabs(model->UU[36] - 0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 36");
+ mu_assert(fabs(model->UU[37] - 0.2886751345948129) < 1e-14,
+ "Incorrect value of model->UU at 37");
+ mu_assert(fabs(model->UU[38] - 0.8164965809277261) < 1e-14,
+ "Incorrect value of model->UU at 38");
+ mu_assert(fabs(model->UU[39] - -0.5000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 39");
+ mu_assert(fabs(model->UU[40] - 0.2886751345948129) < 1e-14,
+ "Incorrect value of model->UU at 40");
+ mu_assert(fabs(model->UU[41] - 0.8164965809277261) < 1e-14,
+ "Incorrect value of model->UU at 41");
+ mu_assert(fabs(model->UU[42] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 42");
+ mu_assert(fabs(model->UU[43] - -0.5773502691896258) < 1e-14,
+ "Incorrect value of model->UU at 43");
+ mu_assert(fabs(model->UU[44] - 0.8164965809277261) < 1e-14,
+ "Incorrect value of model->UU at 44");
+ mu_assert(fabs(model->UU[45] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 45");
+ mu_assert(fabs(model->UU[46] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 46");
+ mu_assert(fabs(model->UU[47] - 0.0000000000000000) < 1e-14,
+ "Incorrect value of model->UU at 47");
+
+ // end test code //
+
+ gensvm_free_model(model);
+ gensvm_free_data(data);
return NULL;
}
@@ -61,6 +196,7 @@ char *all_tests()
mu_suite_start();
mu_run_test(test_simplex_1);
mu_run_test(test_simplex_2);
+ mu_run_test(test_gensvm_simplex_diff);
return NULL;
}