aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2016-10-17 13:41:46 +0200
committerGertjan van den Burg <burg@ese.eur.nl>2016-10-17 13:41:46 +0200
commite2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba (patch)
tree3564a0b9ed66ccf71d16bf54a304aad320876bbf /src
parentupdate doxyfile (diff)
downloadgensvm-e2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba.tar.gz
gensvm-e2c0ca1c082bfd7755c7af5bc5c9021bce64f3ba.zip
Update predictions to work with sparse matrices
This is done by pulling the Z*V routines from the gensvm_optimize file to a seperate file, since they are shared by prediction and get_loss
Diffstat (limited to 'src')
-rw-r--r--src/gensvm_optimize.c43
-rw-r--r--src/gensvm_pred.c19
-rw-r--r--src/gensvm_zv.c115
3 files changed, 118 insertions, 59 deletions
diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c
index e83fa47..300df40 100644
--- a/src/gensvm_optimize.c
+++ b/src/gensvm_optimize.c
@@ -257,10 +257,7 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data,
long n = model->n;
long K = model->K;
- if (data->spZ == NULL)
- gensvm_calculate_ZV_dense(model, data, ZV);
- else
- gensvm_calculate_ZV_sparse(model, data, ZV);
+ gensvm_calculate_ZV(model, data, ZV);
for (i=0; i<n; i++) {
for (j=0; j<K; j++) {
@@ -273,41 +270,3 @@ void gensvm_calculate_errors(struct GenModel *model, struct GenData *data,
}
}
-void gensvm_calculate_ZV_sparse(struct GenModel *model,
- struct GenData *data, double *ZV)
-{
- long i, j, jj, jj_start, jj_end, K,
- n_row = data->spZ->n_row;
- double z_ij;
-
- K = model->K;
-
- int *Zia = data->spZ->ia;
- int *Zja = data->spZ->ja;
- double *vals = data->spZ->values;
-
- for (i=0; i<n_row; i++) {
- jj_start = Zia[i];
- jj_end = Zia[i+1];
-
- for (jj=jj_start; jj<jj_end; jj++) {
- j = Zja[jj];
- z_ij = vals[jj];
-
- cblas_daxpy(K-1, z_ij, &model->V[j*(K-1)], 1,
- &ZV[i*(K-1)], 1);
- }
- }
-}
-
-void gensvm_calculate_ZV_dense(struct GenModel *model,
- struct GenData *data, double *ZV)
-{
- long n = model->n;
- long m = model->m;
- long K = model->K;
-
- cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, K-1, m+1,
- 1.0, data->Z, m+1, model->V, K-1, 0, ZV, K-1);
-}
-
diff --git a/src/gensvm_pred.c b/src/gensvm_pred.c
index afa1ab9..31b591c 100644
--- a/src/gensvm_pred.c
+++ b/src/gensvm_pred.c
@@ -30,13 +30,12 @@
void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
long *predy)
{
- long i, j, k, n, m, K, label;
+ long i, j, k, n, K, label;
double norm, min_dist,
*S = NULL,
*ZV = NULL;
n = testdata->n;
- m = testdata->r;
K = model->K;
// allocate necessary memory
@@ -47,21 +46,7 @@ void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model,
gensvm_simplex(model);
// Generate the simplex space vectors
- cblas_dgemm(
- CblasRowMajor,
- CblasNoTrans,
- CblasNoTrans,
- n,
- K-1,
- m+1,
- 1.0,
- testdata->Z,
- m+1,
- model->V,
- K-1,
- 0.0,
- ZV,
- K-1);
+ gensvm_calculate_ZV(model, testdata, ZV);
// Calculate the distance to each of the vertices of the simplex.
// The closest vertex defines the class label
diff --git a/src/gensvm_zv.c b/src/gensvm_zv.c
new file mode 100644
index 0000000..81a5354
--- /dev/null
+++ b/src/gensvm_zv.c
@@ -0,0 +1,115 @@
+/**
+ * @file gensvm_zv.c
+ * @author Gertjan van den Burg
+ * @date 2016-10-17
+ * @brief Functions for computing the ZV matrix product
+ *
+ * @details
+ * This file exists because the product Z*V of two matrices occurs both in the
+ * computation of the loss function and for predicting class labels. Moreover,
+ * a distinction has to be made between dense Z matrices and sparse Z
+ * matrices, hence a seperate file is warranted.
+
+ * Copyright (C)
+
+ This program is free software; you can redistribute it and/or
+ modify it under the terms of the GNU General Public License
+ as published by the Free Software Foundation; either version 2
+ of the License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+
+ */
+
+#include "gensvm_zv.h"
+
+/**
+ * @brief Wrapper around sparse/dense versions of this function
+ *
+ * @details
+ * This function tests if the data is stored in dense format or sparse format
+ * by testing if GenData::Z is NULL or not, and calls the corresponding
+ * version of this function accordingly.
+ *
+ * @sa
+ * gensvm_calculate_ZV_dense(), gensvm_calculate_ZV_sparse()
+ *
+ * @param[in] model a GenModel instance holding the model
+ * @param[in] data a GenData instance with the data
+ * @param[out] ZV a pre-allocated matrix of appropriate dimensions
+ */
+void gensvm_calculate_ZV(struct GenModel *model, struct GenData *data,
+ double *ZV)
+{
+ if (data->Z == NULL)
+ gensvm_calculate_ZV_sparse(model, data, ZV);
+ else
+ gensvm_calculate_ZV_dense(model, data, ZV);
+}
+
+/**
+ * @brief Compute the product Z*V for when Z is a sparse matrix
+ *
+ * @details
+ * This is a simple sparse-dense matrix multiplication, which uses
+ * cblas_daxpy() for each nonzero element of Z, to compute Z*V.
+ *
+ * @param[in] model a GenModel instance holding the model
+ * @param[in] data a GenData instance with the data
+ * @param[out] ZV a pre-allocated matrix of appropriate dimensions
+ */
+void gensvm_calculate_ZV_sparse(struct GenModel *model,
+ struct GenData *data, double *ZV)
+{
+ long i, j, jj, jj_start, jj_end, K,
+ n_row = data->spZ->n_row;
+ double z_ij;
+
+ K = model->K;
+
+ int *Zia = data->spZ->ia;
+ int *Zja = data->spZ->ja;
+ double *vals = data->spZ->values;
+
+ for (i=0; i<n_row; i++) {
+ jj_start = Zia[i];
+ jj_end = Zia[i+1];
+
+ for (jj=jj_start; jj<jj_end; jj++) {
+ j = Zja[jj];
+ z_ij = vals[jj];
+
+ cblas_daxpy(K-1, z_ij, &model->V[j*(K-1)], 1,
+ &ZV[i*(K-1)], 1);
+ }
+ }
+}
+
+/**
+ * @brief Compute the product Z*V for when Z is a dense matrix
+ *
+ * @details
+ * This function uses cblas_dgemm() to compute the matrix product between Z
+ * and V.
+ *
+ * @param[in] model a GenModel instance holding the model
+ * @param[in] data a GenData instance with the data
+ * @param[out] ZV a pre-allocated matrix of appropriate dimensions
+ */
+void gensvm_calculate_ZV_dense(struct GenModel *model,
+ struct GenData *data, double *ZV)
+{
+ long n = model->n;
+ long m = model->m;
+ long K = model->K;
+
+ cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, K-1, m+1,
+ 1.0, data->Z, m+1, model->V, K-1, 0, ZV, K-1);
+}