aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorGertjan van den Burg <burg@ese.eur.nl>2013-08-06 18:16:40 +0200
committerGertjan van den Burg <burg@ese.eur.nl>2013-08-06 18:16:40 +0200
commit0c4ff1eeb84199e3443f5a534efdb85402c53459 (patch)
tree0457b89b18c2c977ef7c22fc0e37c0bb1cfec461 /include
parentadded datasets and README (diff)
downloadgensvm-0c4ff1eeb84199e3443f5a534efdb85402c53459.tar.gz
gensvm-0c4ff1eeb84199e3443f5a534efdb85402c53459.zip
Added prediction script and model i/o.
Diffstat (limited to 'include')
-rw-r--r--include/libMSVMMaj.h5
-rw-r--r--include/util.h9
2 files changed, 12 insertions, 2 deletions
diff --git a/include/libMSVMMaj.h b/include/libMSVMMaj.h
index c886ded..6db1253 100644
--- a/include/libMSVMMaj.h
+++ b/include/libMSVMMaj.h
@@ -1,7 +1,6 @@
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
-#include <time.h>
#include <cblas.h>
#include <string.h>
#include "util.h"
@@ -24,3 +23,7 @@ void main_loop(struct Model *model, struct Data *data);
int dposv(char UPLO, int N, int NRHS, double *A, int LDA, double *B, int LDB);
void initialize_weights(struct Data *data, struct Model *model);
+
+void predict_labels(struct Data *data, struct Model *model, long *predy);
+double prediction_perf(struct Data *data, long *predy);
+
diff --git a/include/util.h b/include/util.h
index 0b5009e..ec415ac 100644
--- a/include/util.h
+++ b/include/util.h
@@ -3,6 +3,7 @@
#include <stdlib.h>
#include <math.h>
#include <string.h>
+#include <time.h>
#include "MSVMMaj.h"
#define Calloc(type, n) (type *)calloc((n), sizeof(type))
@@ -11,7 +12,12 @@
#define maximum(a, b) a > b ? a : b
#define minimum(a, b) a < b ? a : b
-void read_data(struct Data *dataset, struct Model *model, char *data_file);
+void read_data(struct Data *dataset, char *data_file);
+
+void read_model(struct Model *model, char *model_filename);
+void write_model(struct Model *model, char *output_filename);
+
+void write_predictions(struct Data *data, long *predy, char *output_filename);
int check_argv(int argc, char **argv, char *str);
int check_argv_eq(int argc, char **argv, char *str);
@@ -34,3 +40,4 @@ void free_model(struct Model *model);
void free_data(struct Data *data);
void print_matrix(double *M, long rows, long cols);
+