diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-07-12 23:37:57 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2020-07-12 23:37:57 +0100 |
| commit | 19d98bb3d3ddf7c941c0d1e9df9e3614e0ccd68b (patch) | |
| tree | 58cac0e2c30b7971d187deab2701ee70657b74a9 | |
| parent | Remove unused OPTFLAGS (diff) | |
| download | gensvm-19d98bb3d3ddf7c941c0d1e9df9e3614e0ccd68b.tar.gz gensvm-19d98bb3d3ddf7c941c0d1e9df9e3614e0ccd68b.zip | |
| -rw-r--r-- | include/gensvm_base.h | 2 | ||||
| -rw-r--r-- | include/gensvm_optimize.h | 1 | ||||
| -rw-r--r-- | src/gensvm_base.c | 3 | ||||
| -rw-r--r-- | src/gensvm_optimize.c | 19 |
4 files changed, 20 insertions, 5 deletions
diff --git a/include/gensvm_base.h b/include/gensvm_base.h index f5f205f..16b7508 100644 --- a/include/gensvm_base.h +++ b/include/gensvm_base.h @@ -172,6 +172,8 @@ struct GenWork { ///< n x (K-1) working matrix for the Z * V calculation double *beta; ///< K-1 working vector for a row of the B matrix + long *yhat; + ///< n vector of predicted classes }; // function declarations diff --git a/include/gensvm_optimize.h b/include/gensvm_optimize.h index d7a8248..e12f0f5 100644 --- a/include/gensvm_optimize.h +++ b/include/gensvm_optimize.h @@ -33,6 +33,7 @@ #include "gensvm_sv.h" #include "gensvm_simplex.h" +#include "gensvm_predict.h" #include "gensvm_update.h" #include "gensvm_zv.h" diff --git a/src/gensvm_base.c b/src/gensvm_base.c index 52e1e82..c5d036b 100644 --- a/src/gensvm_base.c +++ b/src/gensvm_base.c @@ -265,6 +265,7 @@ struct GenWork *gensvm_init_work(struct GenModel *model) work->tmpZAZ = Calloc(double, (m+1)*(m+1)), work->ZV = Calloc(double, n*(K-1)); work->beta = Calloc(double, K-1); + work->yhat = Calloc(long, n); return work; } @@ -288,6 +289,7 @@ void gensvm_free_work(struct GenWork *work) free(work->tmpZAZ); free(work->ZV); free(work->beta); + free(work->yhat); free(work); work = NULL; } @@ -317,4 +319,5 @@ void gensvm_reset_work(struct GenWork *work) Memset(work->tmpZAZ, double, (m+1)*(m+1)), Memset(work->ZV, double, n*(K-1)); Memset(work->beta, double, K-1); + Memset(work->yhat, long, n); } diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c index df46ec4..c6c8538 100644 --- a/src/gensvm_optimize.c +++ b/src/gensvm_optimize.c @@ -56,7 +56,7 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data) { long it = 0; - double L, Lbar; + double L, Lbar, acc; long n = model->n; long m = model->m; @@ -98,9 +98,14 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data) Lbar = L; L = gensvm_get_loss(model, data, work); - if (it % GENSVM_PRINT_ITER == 0) + if (it % GENSVM_PRINT_ITER == 0) { + gensvm_predict_labels(data, model, work->yhat); + acc = gensvm_prediction_perf(data, work->yhat); note("iter = %li, L = %15.16f, Lbar = %15.16f, " - "reldiff = %15.16f\n", it, L, Lbar, (Lbar - L)/L); + "reldiff = %15.16f, acc = %.2f\n", it, L, Lbar, + (Lbar - L)/L, acc); + } + it++; } @@ -120,10 +125,14 @@ void gensvm_optimize(struct GenModel *model, struct GenData *data) model->status = 2; } + // compute final training accuracy + gensvm_predict_labels(data, model, work->yhat); + acc = gensvm_prediction_perf(data, work->yhat); + // print final iteration count and loss note("Optimization finished, iter = %li, loss = %15.16f, " - "rel. diff. = %15.16f\n", it-1, L, - (Lbar - L)/L); + "rel. diff. = %15.16f, acc = %.2f\n", it-1, L, + (Lbar - L)/L, acc); // compute and print the number of SVs in the model note("Number of support vectors: %li\n", gensvm_num_sv(model)); |
