From f3f55565711893004df14cc4c6ffd86f0b736f2f Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Thu, 6 Oct 2016 16:45:00 +0200 Subject: Switch to using dsyrk instead of dsyr for speed. Also added a workspace (GenWork) structure for to hold working matrices for the gensvm_get_update() and gensvm_get_loss() functions --- include/gensvm_base.h | 31 +++++++++++++++++++++++++++++++ include/gensvm_optimize.h | 7 ++++--- 2 files changed, 35 insertions(+), 3 deletions(-) (limited to 'include') diff --git a/include/gensvm_base.h b/include/gensvm_base.h index efeaa4d..03b7ffa 100644 --- a/include/gensvm_base.h +++ b/include/gensvm_base.h @@ -99,12 +99,43 @@ struct GenModel { ///< array of kernel parameters, size depends on kernel type }; +/** + * @brief A structure to hold the GenSVM workspace + * + */ +struct GenWork { + long n; + ///< number of instances for the workspace + long m; + ///< number of features for the workspace + long K; + ///< number of classes for the workspace + + double *LZ; + ///< n x (m+1) working matrix for the Z'*A*Z calculation + double *ZB; + ///< (m+1) x (K-1) working matrix for the Z'*B calculation + double *ZBc; + ///< (K-1) x (m+1) working matrix for the Z'*B calculation + double *ZAZ; + ///< (m+1) x (m+1) working matrix for the Z'*A*Z calculation + double *ZV; + ///< n x (K-1) working matrix for the Z * V calculation + double *beta; + ///< K-1 working vector for a row of the B matrix +}; + // function declarations struct GenModel *gensvm_init_model(); void gensvm_allocate_model(struct GenModel *model); void gensvm_reallocate_model(struct GenModel *model, long n, long m); void gensvm_free_model(struct GenModel *model); + struct GenData *gensvm_init_data(); void gensvm_free_data(struct GenData *data); +struct GenWork *gensvm_init_work(struct GenModel *model); +void gensvm_free_work(struct GenWork *work); +void gensvm_reset_work(struct GenWork *work); + #endif diff --git a/include/gensvm_optimize.h b/include/gensvm_optimize.h index bbdf4c8..dec8914 100644 --- a/include/gensvm_optimize.h +++ b/include/gensvm_optimize.h @@ -19,8 +19,8 @@ // function declarations void gensvm_optimize(struct GenModel *model, struct GenData *data); -double gensvm_get_loss(struct GenModel *model, struct GenData *data, - double *ZV); +double gensvm_get_loss(struct GenModel *model, struct GenData *data, + struct GenWork *work); double gensvm_calculate_omega(struct GenModel *model, struct GenData *data, long i); @@ -33,7 +33,8 @@ void gensvm_calculate_ab_simple(struct GenModel *model, long i, long j, double gensvm_get_alpha_beta(struct GenModel *model, struct GenData *data, long i, double *beta); -void gensvm_get_update(struct GenModel *model, struct GenData *data); +void gensvm_get_update(struct GenModel *model, struct GenData *data, + struct GenWork *work); void gensvm_calculate_errors(struct GenModel *model, struct GenData *data, double *ZV); void gensvm_calculate_huber(struct GenModel *model); -- cgit v1.2.3