aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2020-07-12 23:37:57 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2020-07-12 23:37:57 +0100
commit19d98bb3d3ddf7c941c0d1e9df9e3614e0ccd68b (patch)
tree58cac0e2c30b7971d187deab2701ee70657b74a9
parentRemove unused OPTFLAGS (diff)
downloadgensvm-19d98bb3d3ddf7c941c0d1e9df9e3614e0ccd68b.tar.gz
gensvm-19d98bb3d3ddf7c941c0d1e9df9e3614e0ccd68b.zip
Print training accuracyHEADmaster
-rw-r--r--include/gensvm_base.h2
-rw-r--r--include/gensvm_optimize.h1
-rw-r--r--src/gensvm_base.c3
-rw-r--r--src/gensvm_optimize.c19
4 files changed, 20 insertions, 5 deletions
diff --git a/include/gensvm_base.h b/include/gensvm_base.h
index f5f205f..16b7508 100644
--- a/include/gensvm_base.h
+++ b/include/gensvm_base.h
@@ -172,6 +172,8 @@ struct GenWork {
///< n x (K-1) working matrix for the Z * V calculation
double *beta;
///< K-1 working vector for a row of the B matrix
+ long *yhat;
+ ///< n vector of predicted classes
};
// function declarations
diff --git a/include/gensvm_optimize.h b/include/gensvm_optimize.h
index d7a8248..e12f0f5 100644
--- a/include/gensvm_optimize.h
+++ b/include/gensvm_optimize.h
@@ -33,6 +33,7 @@
#include "gensvm_sv.h"
#include "gensvm_simplex.h"
+#include "gensvm_predict.h"
#include "gensvm_update.h"
#include "gensvm_zv.h"
diff --git a/src/gensvm_base.c b/src/gensvm_base.c
index 52e1e82..c5d036b 100644
--- a/src/gensvm_base.c
+++ b/src/gensvm_base.c
@@ -265,6 +265,7 @@ struct GenWork *gensvm_init_work(struct GenModel *model)
work->tmpZAZ = Calloc(double, (m+1)*(m+1)),
work->ZV = Calloc(double, n*(K-1));
work->beta = Calloc(double, K-1);
+ work->yhat = Calloc(long, n);
return work;
}
@@ -288,6 +289,7 @@ void gensvm_free_work(struct GenWork *work)
free(work->tmpZAZ);
free(work->ZV);
free(work->beta);
+ free(work->yhat);
free(work);
work = NULL;
}
@@ -317,4 +319,5 @@ void gensvm_reset_work(struct GenWork *work)
Memset(work->tmpZAZ, double, (m+1)*(m+1)),
Memset(work->ZV, double, n*(K-1));
Memset(work->beta, double, K-1);
+ Memset(work->yhat, long, n);
}
diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c
index df46ec4..c6c8538 100644
--- a/src/gensvm_optimize.c
+++ b/src/gensvm_optimize.c
@@ -56,7 +56,7 @@
void gensvm_optimize(struct GenModel *model, struct GenData *data)
{
long it = 0;
- double L, Lbar;
+ double L, Lbar, acc;
long n = model->n;
long m = model->m;
@@ -98,9 +98,14 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data)
Lbar = L;
L = gensvm_get_loss(model, data, work);
- if (it % GENSVM_PRINT_ITER == 0)
+ if (it % GENSVM_PRINT_ITER == 0) {
+ gensvm_predict_labels(data, model, work->yhat);
+ acc = gensvm_prediction_perf(data, work->yhat);
note("iter = %li, L = %15.16f, Lbar = %15.16f, "
- "reldiff = %15.16f\n", it, L, Lbar, (Lbar - L)/L);
+ "reldiff = %15.16f, acc = %.2f\n", it, L, Lbar,
+ (Lbar - L)/L, acc);
+ }
+
it++;
}
@@ -120,10 +125,14 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data)
model->status = 2;
}
+ // compute final training accuracy
+ gensvm_predict_labels(data, model, work->yhat);
+ acc = gensvm_prediction_perf(data, work->yhat);
+
// print final iteration count and loss
note("Optimization finished, iter = %li, loss = %15.16f, "
- "rel. diff. = %15.16f\n", it-1, L,
- (Lbar - L)/L);
+ "rel. diff. = %15.16f, acc = %.2f\n", it-1, L,
+ (Lbar - L)/L, acc);
// compute and print the number of SVs in the model
note("Number of support vectors: %li\n", gensvm_num_sv(model));