aboutsummaryrefslogtreecommitdiff
path: root/tests/src/test_gensvm_optimize.c
diff options
context:
space:
mode:
Diffstat (limited to 'tests/src/test_gensvm_optimize.c')
-rw-r--r--tests/src/test_gensvm_optimize.c43
1 files changed, 21 insertions, 22 deletions
diff --git a/tests/src/test_gensvm_optimize.c b/tests/src/test_gensvm_optimize.c
index de7e074..d7326f7 100644
--- a/tests/src/test_gensvm_optimize.c
+++ b/tests/src/test_gensvm_optimize.c
@@ -137,12 +137,17 @@ char *test_gensvm_get_loss_1()
int n = 8,
m = 3,
K = 3;
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ struct GenWork *work = gensvm_init_work(model);
// initialize the data
data->n = n;
data->K = K;
data->m = m;
+
data->y = Calloc(long, data->n);
data->y[0] = 2;
data->y[1] = 1;
@@ -188,9 +193,6 @@ char *test_gensvm_get_loss_1()
matrix_set(data->Z, data->m+1, 7, 3, 0.0540516162042732);
// initialize the model
- model->n = n;
- model->m = m;
- model->K = K;
model->weight_idx = 1;
model->kappa = 0.5;
model->p = 1.5;
@@ -211,17 +213,16 @@ char *test_gensvm_get_loss_1()
matrix_set(model->V, model->K-1, 3, 1, 0.5912336906588613);
// start test code //
- double *ZV = Calloc(double, (data->n)*(data->K-1));
- double loss = gensvm_get_loss(model, data, ZV);
+ double loss = gensvm_get_loss(model, data, work);
mu_assert(fabs(loss - 0.903071383013108) < 1e-14,
"Incorrect value of the loss");
- free(ZV);
// end test code //
gensvm_free_model(model);
gensvm_free_data(data);
+ gensvm_free_work(work);
return NULL;
@@ -235,6 +236,10 @@ char *test_gensvm_get_loss_2()
int n = 8,
m = 3,
K = 3;
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ struct GenWork *work = gensvm_init_work(model);
// initialize the data
data->n = n;
@@ -286,9 +291,6 @@ char *test_gensvm_get_loss_2()
matrix_set(data->Z, data->m+1, 7, 3, 0.0540516162042732);
// initialize the model
- model->n = n;
- model->m = m;
- model->K = K;
model->weight_idx = 2;
model->kappa = 0.5;
model->p = 1.5;
@@ -309,17 +311,15 @@ char *test_gensvm_get_loss_2()
matrix_set(model->V, model->K-1, 3, 1, 0.5912336906588613);
// start test code //
- double *ZV = Calloc(double, (data->n)*(data->K-1));
- double loss = gensvm_get_loss(model, data, ZV);
+ double loss = gensvm_get_loss(model, data, work);
mu_assert(fabs(loss - 0.972847045993281) < 1e-14,
"Incorrect value of the loss");
-
- free(ZV);
// end test code //
gensvm_free_model(model);
gensvm_free_data(data);
+ gensvm_free_work(work);
return NULL;
}
@@ -652,6 +652,11 @@ char *test_gensvm_get_update()
m = 3,
K = 3;
+ model->n = n;
+ model->m = m;
+ model->K = K;
+ struct GenWork *work = gensvm_init_work(model);
+
// initialize data
data->n = n;
data->m = m;
@@ -702,9 +707,6 @@ char *test_gensvm_get_update()
matrix_set(data->Z, data->m+1, 7, 3, -0.7292593770500276);
// initialize model
- model->n = n;
- model->m = m;
- model->K = K;
model->p = 1.1;
model->lambda = 0.123;
model->weight_idx = 1;
@@ -727,14 +729,13 @@ char *test_gensvm_get_update()
matrix_set(model->V, model->K-1, 3, 1, 0.7134997072555367);
// start test code //
- double *ZV = Calloc(double, n*(K-1));
// these need to be prepared for the update call
- gensvm_calculate_errors(model, data, ZV);
+ gensvm_calculate_errors(model, data, work->ZV);
gensvm_calculate_huber(model);
// run the actual update call
- gensvm_get_update(model, data);
+ gensvm_get_update(model, data, work);
// test values
mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 0) -
@@ -761,13 +762,11 @@ char *test_gensvm_get_update()
mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 1) -
0.4390030236354089) < 1e-14,
"Incorrect value of model->V at 3, 1");
-
- free(ZV);
-
// end test code //
gensvm_free_model(model);
gensvm_free_data(data);
+ gensvm_free_work(work);
return NULL;
}