diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2016-10-17 13:41:46 +0200 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2016-10-17 13:41:46 +0200 |
| commit | e2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba (patch) | |
| tree | 3564a0b9ed66ccf71d16bf54a304aad320876bbf /src | |
| parent | update doxyfile (diff) | |
| download | gensvm-e2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba.tar.gz gensvm-e2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba.zip | |
Update predictions to work with sparse matrices
This is done by pulling the Z*V routines from the gensvm_optimize file
to a seperate file, since they are shared by prediction and get_loss
Diffstat (limited to 'src')
| -rw-r--r-- | src/gensvm_optimize.c | 43 | ||||
| -rw-r--r-- | src/gensvm_pred.c | 19 | ||||
| -rw-r--r-- | src/gensvm_zv.c | 115 |
3 files changed, 118 insertions, 59 deletions
diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c index e83fa47..300df40 100644 --- a/src/gensvm_optimize.c +++ b/src/gensvm_optimize.c @@ -257,10 +257,7 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data, long n = model->n; long K = model->K; - if (data->spZ == NULL) - gensvm_calculate_ZV_dense(model, data, ZV); - else - gensvm_calculate_ZV_sparse(model, data, ZV); + gensvm_calculate_ZV(model, data, ZV); for (i=0; i<n; i++) { for (j=0; j<K; j++) { @@ -273,41 +270,3 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data, } } -void gensvm_calculate_ZV_sparse(struct GenModel *model, - struct GenData *data, double *ZV) -{ - long i, j, jj, jj_start, jj_end, K, - n_row = data->spZ->n_row; - double z_ij; - - K = model->K; - - int *Zia = data->spZ->ia; - int *Zja = data->spZ->ja; - double *vals = data->spZ->values; - - for (i=0; i<n_row; i++) { - jj_start = Zia[i]; - jj_end = Zia[i+1]; - - for (jj=jj_start; jj<jj_end; jj++) { - j = Zja[jj]; - z_ij = vals[jj]; - - cblas_daxpy(K-1, z_ij, &model->V[j*(K-1)], 1, - &ZV[i*(K-1)], 1); - } - } -} - -void gensvm_calculate_ZV_dense(struct GenModel *model, - struct GenData *data, double *ZV) -{ - long n = model->n; - long m = model->m; - long K = model->K; - - cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, K-1, m+1, - 1.0, data->Z, m+1, model->V, K-1, 0, ZV, K-1); -} - diff --git a/src/gensvm_pred.c b/src/gensvm_pred.c index afa1ab9..31b591c 100644 --- a/src/gensvm_pred.c +++ b/src/gensvm_pred.c @@ -30,13 +30,12 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, long *predy) { - long i, j, k, n, m, K, label; + long i, j, k, n, K, label; double norm, min_dist, *S = NULL, *ZV = NULL; n = testdata->n; - m = testdata->r; K = model->K; // allocate necessary memory @@ -47,21 +46,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, gensvm_simplex(model); // Generate the simplex space vectors - cblas_dgemm( - CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - n, - K-1, - m+1, - 1.0, - testdata->Z, - m+1, - model->V, - K-1, - 0.0, - ZV, - K-1); + gensvm_calculate_ZV(model, testdata, ZV); // Calculate the distance to each of the vertices of the simplex. // The closest vertex defines the class label diff --git a/src/gensvm_zv.c b/src/gensvm_zv.c new file mode 100644 index 0000000..81a5354 --- /dev/null +++ b/src/gensvm_zv.c @@ -0,0 +1,115 @@ +/** + * @file gensvm_zv.c + * @author Gertjan van den Burg + * @date 2016-10-17 + * @brief Functions for computing the ZV matrix product + * + * @details + * This file exists because the product Z*V of two matrices occurs both in the + * computation of the loss function and for predicting class labels. Moreover, + * a distinction has to be made between dense Z matrices and sparse Z + * matrices, hence a seperate file is warranted. + + * Copyright (C) + + This program is free software; you can redistribute it and/or + modify it under the terms of the GNU General Public License + as published by the Free Software Foundation; either version 2 + of the License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program; if not, write to the Free Software + Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + + */ + +#include "gensvm_zv.h" + +/** + * @brief Wrapper around sparse/dense versions of this function + * + * @details + * This function tests if the data is stored in dense format or sparse format + * by testing if GenData::Z is NULL or not, and calls the corresponding + * version of this function accordingly. + * + * @sa + * gensvm_calculate_ZV_dense(), gensvm_calculate_ZV_sparse() + * + * @param[in] model a GenModel instance holding the model + * @param[in] data a GenData instance with the data + * @param[out] ZV a pre-allocated matrix of appropriate dimensions + */ +void gensvm_calculate_ZV(struct GenModel *model, struct GenData *data, + double *ZV) +{ + if (data->Z == NULL) + gensvm_calculate_ZV_sparse(model, data, ZV); + else + gensvm_calculate_ZV_dense(model, data, ZV); +} + +/** + * @brief Compute the product Z*V for when Z is a sparse matrix + * + * @details + * This is a simple sparse-dense matrix multiplication, which uses + * cblas_daxpy() for each nonzero element of Z, to compute Z*V. + * + * @param[in] model a GenModel instance holding the model + * @param[in] data a GenData instance with the data + * @param[out] ZV a pre-allocated matrix of appropriate dimensions + */ +void gensvm_calculate_ZV_sparse(struct GenModel *model, + struct GenData *data, double *ZV) +{ + long i, j, jj, jj_start, jj_end, K, + n_row = data->spZ->n_row; + double z_ij; + + K = model->K; + + int *Zia = data->spZ->ia; + int *Zja = data->spZ->ja; + double *vals = data->spZ->values; + + for (i=0; i<n_row; i++) { + jj_start = Zia[i]; + jj_end = Zia[i+1]; + + for (jj=jj_start; jj<jj_end; jj++) { + j = Zja[jj]; + z_ij = vals[jj]; + + cblas_daxpy(K-1, z_ij, &model->V[j*(K-1)], 1, + &ZV[i*(K-1)], 1); + } + } +} + +/** + * @brief Compute the product Z*V for when Z is a dense matrix + * + * @details + * This function uses cblas_dgemm() to compute the matrix product between Z + * and V. + * + * @param[in] model a GenModel instance holding the model + * @param[in] data a GenData instance with the data + * @param[out] ZV a pre-allocated matrix of appropriate dimensions + */ +void gensvm_calculate_ZV_dense(struct GenModel *model, + struct GenData *data, double *ZV) +{ + long n = model->n; + long m = model->m; + long K = model->K; + + cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, K-1, m+1, + 1.0, data->Z, m+1, model->V, K-1, 0, ZV, K-1); +} |
