diff options
| -rw-r--r-- | include/gensvm_base.h | 2 | ||||
| -rw-r--r-- | include/gensvm_simplex.h | 5 | ||||
| -rw-r--r-- | src/gensvm_base.c | 5 | ||||
| -rw-r--r-- | src/gensvm_optimize.c | 58 | ||||
| -rw-r--r-- | src/gensvm_pred.c | 8 | ||||
| -rw-r--r-- | src/gensvm_simplex.c | 47 | ||||
| -rw-r--r-- | tests/src/test_gensvm_optimize.c | 210 | ||||
| -rw-r--r-- | tests/src/test_gensvm_simplex.c | 182 |
8 files changed, 236 insertions, 281 deletions
diff --git a/include/gensvm_base.h b/include/gensvm_base.h index 390a9b6..efeaa4d 100644 --- a/include/gensvm_base.h +++ b/include/gensvm_base.h @@ -82,7 +82,7 @@ struct GenModel { double *U; ///< simplex matrix double *UU; - ///< 3D simplex difference matrix + ///< simplex difference matrix double *Q; ///< error matrix double *H; diff --git a/include/gensvm_simplex.h b/include/gensvm_simplex.h index d774afa..4f8a475 100644 --- a/include/gensvm_simplex.h +++ b/include/gensvm_simplex.h @@ -10,9 +10,10 @@ #define GENSVM_SIMPLEX_H // includes -#include "gensvm_globals.h" +#include "gensvm_base.h" // forward declarations -void gensvm_simplex(long K, double *U); +void gensvm_simplex(struct GenModel *model); +void gensvm_simplex_diff(struct GenModel *model); #endif diff --git a/src/gensvm_base.c b/src/gensvm_base.c index 2c30530..c77e704 100644 --- a/src/gensvm_base.c +++ b/src/gensvm_base.c @@ -119,7 +119,7 @@ void gensvm_allocate_model(struct GenModel *model) model->V = Calloc(double, (m+1)*(K-1)); model->Vbar = Calloc(double, (m+1)*(K-1)); model->U = Calloc(double, K*(K-1)); - model->UU = Calloc(double, n*K*(K-1)); + model->UU = Calloc(double, K*K*(K-1)); model->Q = Calloc(double, n*K); model->H = Calloc(double, n*K); model->rho = Calloc(double, n); @@ -145,9 +145,6 @@ void gensvm_reallocate_model(struct GenModel *model, long n, long m) if (model->n == n && model->m == m) return; if (model->n != n) { - model->UU = Realloc(model->UU, double, n*K*(K-1)); - Memset(model->UU, double, n*K*(K-1)); - model->Q = Realloc(model->Q, double, n*K); Memset(model->Q, double, n*K); diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c index d08f917..a16512f 100644 --- a/src/gensvm_optimize.c +++ b/src/gensvm_optimize.c @@ -66,9 +66,8 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data) note("\tepsilon = %g\n", model->epsilon); note("\n"); - gensvm_simplex(model->K, model->U); - gensvm_simplex_diff(model, data); gensvm_simplex(model); + gensvm_simplex_diff(model); L = gensvm_get_loss(model, data, ZV); Lbar = L + 2.0*model->epsilon*L; @@ -576,39 +575,6 @@ void gensvm_get_update(struct GenModel *model, struct GenData *data, } } -/** - * @brief Generate the simplex difference matrix - * - * @details - * The simplex difference matrix is a 3D matrix which is constructed - * as follows. For each instance i, the difference vectors between the row of - * the simplex matrix corresponding to the class label of instance i and the - * other rows of the simplex matrix are calculated. These difference vectors - * are stored in a matrix, which is one horizontal slice of the 3D matrix. - * - * We use the indices i, j, k for the three dimensions n, K-1, K of UU. Then - * the i,j,k -th element of UU is equal to U(y[i]-1, j) - U(k, j). - * - * @param[in,out] model the corresponding GenModel - * @param[in] data the corresponding GenData - * - */ -void gensvm_simplex_diff(struct GenModel *model, struct GenData *data) -{ - long i, j, k; - double value; - - long n = model->n; - long K = model->K; - - for (i=0; i<n; i++) { - for (j=0; j<K-1; j++) { - for (k=0; k<K; k++) { - value = matrix_get(model->U, K-1, - data->y[i]-1, j); - value -= matrix_get(model->U, K-1, k, j); - matrix3_set(model->UU, K-1, K, i, j, k, value); - } } } } @@ -689,21 +655,17 @@ void gensvm_calculate_huber(struct GenModel *model) * allocated. In addition, the matrix ZV is calculated here. It is assigned * to a pre-allocated block of memory, which is passed to this function. * - * @todo - * Transform UU to small UU then fix that here - * * @param[in,out] model the corresponding GenModel * @param[in] data the corresponding GenData * @param[in,out] ZV a pointer to a memory block for ZV. On exit * this block is updated with the new ZV matrix * calculated with GenModel::V - * */ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data, double *ZV) { - long i, j, k; - double zv, value; + long i, j; + double q, *uu_row; long n = model->n; long m = model->m; @@ -725,15 +687,13 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data, ZV, K-1); - Memset(model->Q, double, n*K); for (i=0; i<n; i++) { - for (j=0; j<K-1; j++) { - zv = matrix_get(ZV, K-1, i, j); - for (k=0; k<K; k++) { - value = zv * matrix3_get(model->UU, K-1, K, i, - j, k); - matrix_add(model->Q, K, i, k, value); - } + for (j=0; j<K; j++) { + if (j == (data->y[i]-1)) + continue; + uu_row = &model->UU[((data->y[i]-1)*K+j)*(K-1)]; + q = cblas_ddot(K-1, &ZV[i*(K-1)], 1, uu_row, 1); + matrix_set(model->Q, K, i, j, q); } } } diff --git a/src/gensvm_pred.c b/src/gensvm_pred.c index 8a9a43e..43b27cc 100644 --- a/src/gensvm_pred.c +++ b/src/gensvm_pred.c @@ -31,7 +31,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, long *predy) { long i, j, k, n, m, K, label; - double norm, min_dist, *S, *ZV, *U; + double norm, min_dist, *S, *ZV; n = testdata->n; m = testdata->r; @@ -40,10 +40,9 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, // allocate necessary memory S = Calloc(double, K-1); ZV = Calloc(double, n*(K-1)); - U = Calloc(double, K*(K-1)); // Generate the simplex matrix - gensvm_simplex(K, U); + gensvm_simplex(model); // Generate the simplex space vectors cblas_dgemm( @@ -70,7 +69,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, for (j=0; j<K; j++) { for (k=0; k<K-1; k++) { S[k] = matrix_get(ZV, K-1, i, k) - - matrix_get(U, K-1, j, k); + matrix_get(model->U, K-1, j, k); } norm = cblas_dnrm2(K-1, S, 1); if (norm < min_dist) { @@ -82,7 +81,6 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, } free(ZV); - free(U); free(S); } diff --git a/src/gensvm_simplex.c b/src/gensvm_simplex.c index 1fd5f14..a704d85 100644 --- a/src/gensvm_simplex.c +++ b/src/gensvm_simplex.c @@ -25,21 +25,58 @@ * @param[in] K number of classes * @param[in,out] U simplex matrix of size K * (K-1) */ -void gensvm_simplex(long K, double *U) +void gensvm_simplex(struct GenModel *model) { - long i, j; + long i, j, K = model->K; + for (i=0; i<K; i++) { for (j=0; j<K-1; j++) { if (i <= j) { - matrix_set(U, K-1, i, j, + matrix_set(model->U, K-1, i, j, -1.0/sqrt(2.0*(j+1)*(j+2))); } else if (i == j+1) { - matrix_set(U, K-1, i, j, + matrix_set(model->U, K-1, i, j, sqrt((j+1)/(2.0*(j+2)))); } else { - matrix_set(U, K-1, i, j, 0.0); + matrix_set(model->U, K-1, i, j, 0.0); } } } } +/** + * @brief Generate the simplex difference matrix + * + * @details + * The simplex difference matrix is a 2D block matrix which is constructed + * as follows. For each class i, we have a block of K rows and K-1 columns. + * Each row in the block for class i contains a row vector with the difference + * of the simplex matrix, U(i, :) - U(j, :). + * + * In the paper the notation @f$\boldsymbol{\delta}_{kj}'@f$ is used for the + * difference vector of @f$\textbf{u}_k' - \textbf{u}_j'@f$, where + * @f$\textbf{u}_k'@f$ corresponds to row k of @f$\textbf{U}@f$. Due to the + * indexing in the paper being 1-based and C indexing is 0 based, the vector + * @f$\boldsymbol{\delta}_{kj}'@f$ corresponds to the row (k-1)*K+(j-1) in the + * UU matrix generated here. + * + * @param[in,out] model the corresponding GenModel + * + */ +void gensvm_simplex_diff(struct GenModel *model) +{ + long i, j, l, K = model->K; + double value; + + // UU is a 2D block matrix, where block i has the differences: + // U(i, :) - U(j, :) for all j + for (i=0; i<K; i++) { + for (j=0; j<K; j++) { + for (l=0; l<K-1; l++) { + value = matrix_get(model->U, K-1, i, l); + value -= matrix_get(model->U, K-1, j, l); + matrix_set(model->UU, K-1, i*K+j, l, value); + } + } + } +} diff --git a/tests/src/test_gensvm_optimize.c b/tests/src/test_gensvm_optimize.c index f0947a4..6a6571d 100644 --- a/tests/src/test_gensvm_optimize.c +++ b/tests/src/test_gensvm_optimize.c @@ -196,8 +196,8 @@ char *test_gensvm_get_loss_1() gensvm_allocate_model(model); gensvm_initialize_weights(data, model); - gensvm_simplex(model->K, model->U); - gensvm_simplex_diff(model, data); + gensvm_simplex(model); + gensvm_simplex_diff(model); matrix_set(model->V, model->K-1, 0, 0, 0.6019309459245683); matrix_set(model->V, model->K-1, 0, 1, 0.0063825200426701); @@ -294,8 +294,8 @@ char *test_gensvm_get_loss_2() gensvm_allocate_model(model); gensvm_initialize_weights(data, model); - gensvm_simplex(model->K, model->U); - gensvm_simplex_diff(model, data); + gensvm_simplex(model); + gensvm_simplex_diff(model); matrix_set(model->V, model->K-1, 0, 0, 0.6019309459245683); matrix_set(model->V, model->K-1, 0, 1, 0.0063825200426701); @@ -365,19 +365,19 @@ char *test_gensvm_calculate_omega() matrix_set(model->H, model->K, 4, 2, 0.8184193969741095); // start test code // - mu_assert(fabs(gensvm_calculate_omega(model, 0) - + mu_assert(fabs(gensvm_calculate_omega(model, data, 0) - 0.7394076262220608) < 1e-14, "Incorrect omega at 0"); - mu_assert(fabs(gensvm_calculate_omega(model, 1) - + mu_assert(fabs(gensvm_calculate_omega(model, data, 1) - 0.7294526264247443) < 1e-14, "Incorrect omega at 1"); - mu_assert(fabs(gensvm_calculate_omega(model, 2) - + mu_assert(fabs(gensvm_calculate_omega(model, data, 2) - 0.6802499471888741) < 1e-14, "Incorrect omega at 2"); - mu_assert(fabs(gensvm_calculate_omega(model, 3) - + mu_assert(fabs(gensvm_calculate_omega(model, data, 3) - 0.6886792032441273) < 1e-14, "Incorrect omega at 3"); - mu_assert(fabs(gensvm_calculate_omega(model, 4) - + mu_assert(fabs(gensvm_calculate_omega(model, data, 4) - 0.8695329737474283) < 1e-14, "Incorrect omega at 4"); @@ -593,8 +593,8 @@ char *test_gensvm_update_B() model->kappa = 0.5; gensvm_allocate_model(model); - gensvm_simplex(model->K, model->U); - gensvm_simplex_diff(model, data); + gensvm_simplex(model); + gensvm_simplex_diff(model); gensvm_initialize_weights(data, model); // start test code // @@ -639,6 +639,7 @@ char *test_gensvm_get_Avalue_update_B() mu_test_missing(); return NULL; } +*/ char *test_gensvm_get_update() { @@ -709,8 +710,8 @@ char *test_gensvm_get_update() // initialize matrices gensvm_allocate_model(model); gensvm_initialize_weights(data, model); - gensvm_simplex(model->K, model->U); - gensvm_simplex_diff(model, data); + gensvm_simplex(model); + gensvm_simplex_diff(model); // initialize V matrix_set(model->V, model->K-1, 0, 0, -0.7593642121025029); @@ -729,7 +730,7 @@ char *test_gensvm_get_update() double *ZAZVT = Calloc(double, (m+1)*(K-1)); // these need to be prepared for the update call - gensvm_calculate_errors(model, data); + gensvm_calculate_errors(model, data, ZV); gensvm_calculate_huber(model); // run the actual update call @@ -774,182 +775,8 @@ char *test_gensvm_get_update() return NULL; } -char *test_gensvm_simplex_diff() -{ - struct GenData *data = gensvm_init_data(); - struct GenModel *model = gensvm_init_model(); - - int n = 8, - m = 3, - K = 3; - model->n = n; - model->m = m; - model->K = K; - data->n = n; - data->m = m; - data->K = K; - - data->y = Calloc(long, n); - - gensvm_allocate_model(model); - gensvm_simplex(model->K, model->U); - - data->y[0] = 2; - data->y[1] = 1; - data->y[2] = 3; - data->y[3] = 2; - data->y[4] = 3; - data->y[5] = 3; - data->y[6] = 1; - data->y[7] = 2; - - // start test code // - gensvm_simplex_diff(model, data); + free(ZV); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 0) - - 1.0000000000000000) < 1e-14, - "Incorrect value at UU(0, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(0, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 0, 2) - - 0.5000000000000000) < 1e-14, - "Incorrect value at UU(0, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(0, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(0, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 0, 1, 2) - - -0.8660254037844388) < 1e-14, - "Incorrect value at UU(0, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(1, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 1) - - -1.0000000000000000) < 1e-14, - "Incorrect value at UU(1, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 0, 2) - - -0.5000000000000000) < 1e-14, - "Incorrect value at UU(1, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(1, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(1, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 1, 1, 2) - - -0.8660254037844388) < 1e-14, - "Incorrect value at UU(1, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 0) - - 0.5000000000000000) < 1e-14, - "Incorrect value at UU(2, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 1) - - -0.5000000000000000) < 1e-14, - "Incorrect value at UU(2, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 0, 2) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(2, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 0) - - 0.8660254037844388) < 1e-14, - "Incorrect value at UU(2, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 1) - - 0.8660254037844388) < 1e-14, - "Incorrect value at UU(2, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 2, 1, 2) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(2, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 0) - - 1.0000000000000000) < 1e-14, - "Incorrect value at UU(3, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(3, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 0, 2) - - 0.5000000000000000) < 1e-14, - "Incorrect value at UU(3, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(3, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(3, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 3, 1, 2) - - -0.8660254037844388) < 1e-14, - "Incorrect value at UU(3, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 0) - - 0.5000000000000000) < 1e-14, - "Incorrect value at UU(4, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 1) - - -0.5000000000000000) < 1e-14, - "Incorrect value at UU(4, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 0, 2) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(4, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 0) - - 0.8660254037844388) < 1e-14, - "Incorrect value at UU(4, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 1) - - 0.8660254037844388) < 1e-14, - "Incorrect value at UU(4, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 4, 1, 2) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(4, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 0) - - 0.5000000000000000) < 1e-14, - "Incorrect value at UU(5, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 1) - - -0.5000000000000000) < 1e-14, - "Incorrect value at UU(5, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 0, 2) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(5, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 0) - - 0.8660254037844388) < 1e-14, - "Incorrect value at UU(5, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 1) - - 0.8660254037844388) < 1e-14, - "Incorrect value at UU(5, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 5, 1, 2) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(5, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(6, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 1) - - -1.0000000000000000) < 1e-14, - "Incorrect value at UU(6, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 0, 2) - - -0.5000000000000000) < 1e-14, - "Incorrect value at UU(6, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(6, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(6, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 6, 1, 2) - - -0.8660254037844388) < 1e-14, - "Incorrect value at UU(6, 1, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 0) - - 1.0000000000000000) < 1e-14, - "Incorrect value at UU(7, 0, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(7, 0, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 0, 2) - - 0.5000000000000000) < 1e-14, - "Incorrect value at UU(7, 0, 2)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 0) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(7, 1, 0)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 1) - - 0.0000000000000000) < 1e-14, - "Incorrect value at UU(7, 1, 1)"); - mu_assert(fabs(matrix3_get(model->UU, K-1, K, 7, 1, 2) - - -0.8660254037844388) < 1e-14, - "Incorrect value at UU(7, 1, 2)"); // end test code // gensvm_free_model(model); @@ -1021,8 +848,8 @@ char *test_gensvm_calculate_errors() model->m = m; model->K = K; gensvm_allocate_model(model); - gensvm_simplex(model->K, model->U); - gensvm_simplex_diff(model, data); + gensvm_simplex(model); + gensvm_simplex_diff(model); matrix_set(model->V, model->K-1, 0, 0, 0.6019309459245683); matrix_set(model->V, model->K-1, 0, 1, 0.0063825200426701); @@ -1629,7 +1456,6 @@ char *all_tests() mu_run_test(test_gensvm_get_Avalue_update_B); mu_run_test(test_gensvm_get_update); - mu_run_test(test_gensvm_simplex_diff); mu_run_test(test_gensvm_calculate_errors); mu_run_test(test_gensvm_calculate_huber); mu_run_test(test_gensvm_step_doubling); diff --git a/tests/src/test_gensvm_simplex.c b/tests/src/test_gensvm_simplex.c index c13c4ea..c217ca0 100644 --- a/tests/src/test_gensvm_simplex.c +++ b/tests/src/test_gensvm_simplex.c @@ -10,48 +10,183 @@ char *test_simplex_1() { - double *U = Calloc(double, 2*1); + struct GenModel *model = gensvm_init_model(); + model->U = Calloc(double, 2*1); + model->K = 2; - gensvm_simplex(2, U); + gensvm_simplex(model); - mu_assert(matrix_get(U, 1, 0, 0) == -0.5, "U(0, 0) incorrect."); - mu_assert(matrix_get(U, 1, 1, 0) == 0.5, "U(1, 0) incorrect."); + mu_assert(fabs(matrix_get(model->U, 1, 0, 0) - -0.5) < 1e-14, + "U(0, 0) incorrect."); + mu_assert(fabs(matrix_get(model->U, 1, 1, 0) - 0.5) < 1e-14, + "U(1, 0) incorrect."); - free(U); + gensvm_free_model(model); return NULL; } char *test_simplex_2() { - double *U = Calloc(double, 4*3); - - gensvm_simplex(4, U); - - mu_assert(matrix_get(U, 3, 0, 0) == -0.5, "U(0, 0) incorrect."); - mu_assert(matrix_get(U, 3, 1, 0) == 0.5, "U(1, 0) incorrect."); - mu_assert(matrix_get(U, 3, 2, 0) == 0.0, "U(2, 0) incorrect."); - mu_assert(matrix_get(U, 3, 3, 0) == 0.0, "U(3, 0) incorrect."); - - mu_assert(fabs(matrix_get(U, 3, 0, 1) - -0.5/sqrt(3)) < 1e-14, + struct GenModel *model = gensvm_init_model(); + model->U = Calloc(double, 4*3); + model->K = 4; + + gensvm_simplex(model); + + mu_assert(fabs(matrix_get(model->U, 3, 0, 0) - -0.5) < 1e-14, + "U(0, 0) incorrect."); + mu_assert(fabs(matrix_get(model->U, 3, 1, 0) - 0.5) < 1e-14, + "U(1, 0) incorrect."); + mu_assert(fabs(matrix_get(model->U, 3, 2, 0) - 0.0) < 1e-14, + "U(2, 0) incorrect."); + mu_assert(fabs(matrix_get(model->U, 3, 3, 0) - 0.0) < 1e-14, + "U(3, 0) incorrect."); + + mu_assert(fabs(matrix_get(model->U, 3, 0, 1) - -0.5/sqrt(3)) < 1e-14, "U(0, 1) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 1, 1) - -0.5/sqrt(3)) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 1, 1) - -0.5/sqrt(3)) < 1e-14, "U(1, 1) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 2, 1) - 1.0/sqrt(3)) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 2, 1) - 1.0/sqrt(3)) < 1e-14, "U(2, 1) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 3, 1) - 0.0) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 3, 1) - 0.0) < 1e-14, "U(3, 1) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 0, 2) - -1.0/sqrt(24)) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 0, 2) - -1.0/sqrt(24)) < 1e-14, "U(0, 2) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 1, 2) - -1.0/sqrt(24)) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 1, 2) - -1.0/sqrt(24)) < 1e-14, "U(1, 2) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 2, 2) - -1.0/sqrt(24)) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 2, 2) - -1.0/sqrt(24)) < 1e-14, "U(2, 2) incorrect."); - mu_assert(fabs(matrix_get(U, 3, 3, 2) - 3.0/sqrt(24)) < 1e-14, + mu_assert(fabs(matrix_get(model->U, 3, 3, 2) - 3.0/sqrt(24)) < 1e-14, "U(3, 2) incorrect."); - free(U); + gensvm_free_model(model); + + return NULL; +} + +char *test_gensvm_simplex_diff() +{ + struct GenData *data = gensvm_init_data(); + struct GenModel *model = gensvm_init_model(); + + int n = 8, + m = 3, + K = 4; + model->n = n; + model->m = m; + model->K = K; + data->n = n; + data->m = m; + data->K = K; + + gensvm_allocate_model(model); + gensvm_simplex(model); + + // start test code // + gensvm_simplex_diff(model); + mu_assert(fabs(model->UU[0] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 0"); + mu_assert(fabs(model->UU[1] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 1"); + mu_assert(fabs(model->UU[2] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 2"); + mu_assert(fabs(model->UU[3] - -1.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 3"); + mu_assert(fabs(model->UU[4] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 4"); + mu_assert(fabs(model->UU[5] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 5"); + mu_assert(fabs(model->UU[6] - -0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 6"); + mu_assert(fabs(model->UU[7] - -0.8660254037844388) < 1e-14, + "Incorrect value of model->UU at 7"); + mu_assert(fabs(model->UU[8] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 8"); + mu_assert(fabs(model->UU[9] - -0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 9"); + mu_assert(fabs(model->UU[10] - -0.2886751345948129) < 1e-14, + "Incorrect value of model->UU at 10"); + mu_assert(fabs(model->UU[11] - -0.8164965809277261) < 1e-14, + "Incorrect value of model->UU at 11"); + mu_assert(fabs(model->UU[12] - 1.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 12"); + mu_assert(fabs(model->UU[13] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 13"); + mu_assert(fabs(model->UU[14] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 14"); + mu_assert(fabs(model->UU[15] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 15"); + mu_assert(fabs(model->UU[16] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 16"); + mu_assert(fabs(model->UU[17] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 17"); + mu_assert(fabs(model->UU[18] - 0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 18"); + mu_assert(fabs(model->UU[19] - -0.8660254037844388) < 1e-14, + "Incorrect value of model->UU at 19"); + mu_assert(fabs(model->UU[20] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 20"); + mu_assert(fabs(model->UU[21] - 0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 21"); + mu_assert(fabs(model->UU[22] - -0.2886751345948129) < 1e-14, + "Incorrect value of model->UU at 22"); + mu_assert(fabs(model->UU[23] - -0.8164965809277261) < 1e-14, + "Incorrect value of model->UU at 23"); + mu_assert(fabs(model->UU[24] - 0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 24"); + mu_assert(fabs(model->UU[25] - 0.8660254037844388) < 1e-14, + "Incorrect value of model->UU at 25"); + mu_assert(fabs(model->UU[26] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 26"); + mu_assert(fabs(model->UU[27] - -0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 27"); + mu_assert(fabs(model->UU[28] - 0.8660254037844388) < 1e-14, + "Incorrect value of model->UU at 28"); + mu_assert(fabs(model->UU[29] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 29"); + mu_assert(fabs(model->UU[30] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 30"); + mu_assert(fabs(model->UU[31] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 31"); + mu_assert(fabs(model->UU[32] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 32"); + mu_assert(fabs(model->UU[33] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 33"); + mu_assert(fabs(model->UU[34] - 0.5773502691896258) < 1e-14, + "Incorrect value of model->UU at 34"); + mu_assert(fabs(model->UU[35] - -0.8164965809277261) < 1e-14, + "Incorrect value of model->UU at 35"); + mu_assert(fabs(model->UU[36] - 0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 36"); + mu_assert(fabs(model->UU[37] - 0.2886751345948129) < 1e-14, + "Incorrect value of model->UU at 37"); + mu_assert(fabs(model->UU[38] - 0.8164965809277261) < 1e-14, + "Incorrect value of model->UU at 38"); + mu_assert(fabs(model->UU[39] - -0.5000000000000000) < 1e-14, + "Incorrect value of model->UU at 39"); + mu_assert(fabs(model->UU[40] - 0.2886751345948129) < 1e-14, + "Incorrect value of model->UU at 40"); + mu_assert(fabs(model->UU[41] - 0.8164965809277261) < 1e-14, + "Incorrect value of model->UU at 41"); + mu_assert(fabs(model->UU[42] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 42"); + mu_assert(fabs(model->UU[43] - -0.5773502691896258) < 1e-14, + "Incorrect value of model->UU at 43"); + mu_assert(fabs(model->UU[44] - 0.8164965809277261) < 1e-14, + "Incorrect value of model->UU at 44"); + mu_assert(fabs(model->UU[45] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 45"); + mu_assert(fabs(model->UU[46] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 46"); + mu_assert(fabs(model->UU[47] - 0.0000000000000000) < 1e-14, + "Incorrect value of model->UU at 47"); + + // end test code // + + gensvm_free_model(model); + gensvm_free_data(data); return NULL; } @@ -61,6 +196,7 @@ char *all_tests() mu_suite_start(); mu_run_test(test_simplex_1); mu_run_test(test_simplex_2); + mu_run_test(test_gensvm_simplex_diff); return NULL; } |
