diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2016-10-06 16:45:00 +0200 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2016-10-06 16:45:00 +0200 |
| commit | f3f55565711893004df14cc4c6ffd86f0b736f2f (patch) | |
| tree | 814e51de8a47ee01ff10620552cb1660577f02c7 /tests/src/test_gensvm_optimize.c | |
| parent | documentation fixes (diff) | |
| download | gensvm-f3f55565711893004df14cc4c6ffd86f0b736f2f.tar.gz gensvm-f3f55565711893004df14cc4c6ffd86f0b736f2f.zip | |
Switch to using dsyrk instead of dsyr for speed.
Also added a workspace (GenWork) structure for to hold working matrices
for the gensvm_get_update() and gensvm_get_loss() functions
Diffstat (limited to 'tests/src/test_gensvm_optimize.c')
| -rw-r--r-- | tests/src/test_gensvm_optimize.c | 43 |
1 files changed, 21 insertions, 22 deletions
diff --git a/tests/src/test_gensvm_optimize.c b/tests/src/test_gensvm_optimize.c index de7e074..d7326f7 100644 --- a/tests/src/test_gensvm_optimize.c +++ b/tests/src/test_gensvm_optimize.c @@ -137,12 +137,17 @@ char *test_gensvm_get_loss_1() int n = 8, m = 3, K = 3; + model->n = n; + model->m = m; + model->K = K; + struct GenWork *work = gensvm_init_work(model); // initialize the data data->n = n; data->K = K; data->m = m; + data->y = Calloc(long, data->n); data->y[0] = 2; data->y[1] = 1; @@ -188,9 +193,6 @@ char *test_gensvm_get_loss_1() matrix_set(data->Z, data->m+1, 7, 3, 0.0540516162042732); // initialize the model - model->n = n; - model->m = m; - model->K = K; model->weight_idx = 1; model->kappa = 0.5; model->p = 1.5; @@ -211,17 +213,16 @@ char *test_gensvm_get_loss_1() matrix_set(model->V, model->K-1, 3, 1, 0.5912336906588613); // start test code // - double *ZV = Calloc(double, (data->n)*(data->K-1)); - double loss = gensvm_get_loss(model, data, ZV); + double loss = gensvm_get_loss(model, data, work); mu_assert(fabs(loss - 0.903071383013108) < 1e-14, "Incorrect value of the loss"); - free(ZV); // end test code // gensvm_free_model(model); gensvm_free_data(data); + gensvm_free_work(work); return NULL; @@ -235,6 +236,10 @@ char *test_gensvm_get_loss_2() int n = 8, m = 3, K = 3; + model->n = n; + model->m = m; + model->K = K; + struct GenWork *work = gensvm_init_work(model); // initialize the data data->n = n; @@ -286,9 +291,6 @@ char *test_gensvm_get_loss_2() matrix_set(data->Z, data->m+1, 7, 3, 0.0540516162042732); // initialize the model - model->n = n; - model->m = m; - model->K = K; model->weight_idx = 2; model->kappa = 0.5; model->p = 1.5; @@ -309,17 +311,15 @@ char *test_gensvm_get_loss_2() matrix_set(model->V, model->K-1, 3, 1, 0.5912336906588613); // start test code // - double *ZV = Calloc(double, (data->n)*(data->K-1)); - double loss = gensvm_get_loss(model, data, ZV); + double loss = gensvm_get_loss(model, data, work); mu_assert(fabs(loss - 0.972847045993281) < 1e-14, "Incorrect value of the loss"); - - free(ZV); // end test code // gensvm_free_model(model); gensvm_free_data(data); + gensvm_free_work(work); return NULL; } @@ -652,6 +652,11 @@ char *test_gensvm_get_update() m = 3, K = 3; + model->n = n; + model->m = m; + model->K = K; + struct GenWork *work = gensvm_init_work(model); + // initialize data data->n = n; data->m = m; @@ -702,9 +707,6 @@ char *test_gensvm_get_update() matrix_set(data->Z, data->m+1, 7, 3, -0.7292593770500276); // initialize model - model->n = n; - model->m = m; - model->K = K; model->p = 1.1; model->lambda = 0.123; model->weight_idx = 1; @@ -727,14 +729,13 @@ char *test_gensvm_get_update() matrix_set(model->V, model->K-1, 3, 1, 0.7134997072555367); // start test code // - double *ZV = Calloc(double, n*(K-1)); // these need to be prepared for the update call - gensvm_calculate_errors(model, data, ZV); + gensvm_calculate_errors(model, data, work->ZV); gensvm_calculate_huber(model); // run the actual update call - gensvm_get_update(model, data); + gensvm_get_update(model, data, work); // test values mu_assert(fabs(matrix_get(model->V, model->K-1, 0, 0) - @@ -761,13 +762,11 @@ char *test_gensvm_get_update() mu_assert(fabs(matrix_get(model->V, model->K-1, 3, 1) - 0.4390030236354089) < 1e-14, "Incorrect value of model->V at 3, 1"); - - free(ZV); - // end test code // gensvm_free_model(model); gensvm_free_data(data); + gensvm_free_work(work); return NULL; } |
