From e2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Mon, 17 Oct 2016 13:41:46 +0200 Subject: Update predictions to work with sparse matrices This is done by pulling the Z*V routines from the gensvm_optimize file to a seperate file, since they are shared by prediction and get_loss --- tests/src/test_gensvm_pred.c | 103 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 2 deletions(-) (limited to 'tests/src/test_gensvm_pred.c') diff --git a/tests/src/test_gensvm_pred.c b/tests/src/test_gensvm_pred.c index 13f0e5a..155f0bf 100644 --- a/tests/src/test_gensvm_pred.c +++ b/tests/src/test_gensvm_pred.c @@ -24,7 +24,7 @@ * V = [zeros(1, K-1); R]; * */ -char *test_gensvm_predict_labels() +char *test_gensvm_predict_labels_dense() { int n = 12; int m = 2; @@ -116,6 +116,104 @@ char *test_gensvm_predict_labels() return NULL; } +char *test_gensvm_predict_labels_sparse() +{ + int n = 12; + int m = 2; + int K = 3; + + struct GenData *data = gensvm_init_data(); + struct GenModel *model = gensvm_init_model(); + + model->n = n; + model->m = m; + model->K = K; + + data->n = n; + data->m = m; + data->r = m; + data->K = K; + + data->Z = Calloc(double, n*(m+1)); + data->y = Calloc(long, n); + + matrix_set(data->Z, m+1, 0, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 0, 1, -0.3943375672974065); + matrix_set(data->Z, m+1, 0, 2, -0.1056624327025935); + matrix_set(data->Z, m+1, 1, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 1, 1, -0.2886751345948129); + matrix_set(data->Z, m+1, 1, 2, -0.2886751345948128); + matrix_set(data->Z, m+1, 2, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 2, 1, -0.1056624327025937); + matrix_set(data->Z, m+1, 2, 2, -0.3943375672974063); + matrix_set(data->Z, m+1, 3, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 3, 1, 0.1056624327025935); + matrix_set(data->Z, m+1, 3, 2, -0.3943375672974064); + matrix_set(data->Z, m+1, 4, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 4, 1, 0.2886751345948129); + matrix_set(data->Z, m+1, 4, 2, -0.2886751345948129); + matrix_set(data->Z, m+1, 5, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 5, 1, 0.3943375672974064); + matrix_set(data->Z, m+1, 5, 2, -0.1056624327025937); + matrix_set(data->Z, m+1, 6, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 6, 1, 0.3943375672974065); + matrix_set(data->Z, m+1, 6, 2, 0.1056624327025935); + matrix_set(data->Z, m+1, 7, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 7, 1, 0.2886751345948130); + matrix_set(data->Z, m+1, 7, 2, 0.2886751345948128); + matrix_set(data->Z, m+1, 8, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 8, 1, 0.1056624327025939); + matrix_set(data->Z, m+1, 8, 2, 0.3943375672974063); + matrix_set(data->Z, m+1, 9, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 9, 1, -0.1056624327025934); + matrix_set(data->Z, m+1, 9, 2, 0.3943375672974064); + matrix_set(data->Z, m+1, 10, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 10, 1, -0.2886751345948126); + matrix_set(data->Z, m+1, 10, 2, 0.2886751345948132); + matrix_set(data->Z, m+1, 11, 0, 1.0000000000000000); + matrix_set(data->Z, m+1, 11, 1, -0.3943375672974064); + matrix_set(data->Z, m+1, 11, 2, 0.1056624327025939); + + // convert Z to a sparse matrix to test the sparse functions + data->spZ = gensvm_dense_to_sparse(data->Z, data->n, data->m+1); + free(data->Z); + data->RAW = NULL; + data->Z = NULL; + + gensvm_allocate_model(model); + + matrix_set(model->V, K-1, 0, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 0, 1, 0.0000000000000000); + matrix_set(model->V, K-1, 1, 0, -2.4494897427831779); + matrix_set(model->V, K-1, 1, 1, -0.0000000000000002); + matrix_set(model->V, K-1, 2, 0, 0.0000000000000000); + matrix_set(model->V, K-1, 2, 1, -2.4494897427831783); + + // start test code + long *predy = Calloc(long, n); + gensvm_predict_labels(data, model, predy); + + mu_assert(predy[0] == 2, "Incorrect label at index 0"); + mu_assert(predy[1] == 3, "Incorrect label at index 1"); + mu_assert(predy[2] == 3, "Incorrect label at index 2"); + mu_assert(predy[3] == 3, "Incorrect label at index 3"); + mu_assert(predy[4] == 3, "Incorrect label at index 4"); + mu_assert(predy[5] == 1, "Incorrect label at index 5"); + mu_assert(predy[6] == 1, "Incorrect label at index 6"); + mu_assert(predy[7] == 1, "Incorrect label at index 7"); + mu_assert(predy[8] == 1, "Incorrect label at index 8"); + mu_assert(predy[9] == 2, "Incorrect label at index 9"); + mu_assert(predy[10] == 2, "Incorrect label at index 10"); + mu_assert(predy[11] == 2, "Incorrect label at index 11"); + + // end test code + gensvm_free_data(data); + gensvm_free_model(model); + free(predy); + + return NULL; +} + char *test_gensvm_prediction_perf() { int i, n = 8; @@ -156,7 +254,8 @@ char *test_gensvm_prediction_perf() char *all_tests() { mu_suite_start(); - mu_run_test(test_gensvm_predict_labels); + mu_run_test(test_gensvm_predict_labels_dense); + mu_run_test(test_gensvm_predict_labels_sparse); mu_run_test(test_gensvm_prediction_perf); return NULL; -- cgit v1.2.3