aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2016-05-16 21:41:27 +0200
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2016-05-16 21:41:27 +0200
commit2584591266e434140a94d818402745f34fb3acf3 (patch)
tree2faeb4f27b6b4101149e608657d2d018249e73c1
parentmajor refactor of the code (diff)
downloadgensvm-2584591266e434140a94d818402745f34fb3acf3.tar.gz
gensvm-2584591266e434140a94d818402745f34fb3acf3.zip
create a single training function for easy external access
-rw-r--r--Makefile6
-rw-r--r--include/gensvm_optimize.h4
-rw-r--r--include/gensvm_train.h21
-rw-r--r--src/GenSVMtraintest.c45
-rw-r--r--src/gensvm_base.c3
-rw-r--r--src/gensvm_optimize.c2
-rw-r--r--src/gensvm_train.c51
7 files changed, 92 insertions, 40 deletions
diff --git a/Makefile b/Makefile
index d281bb4..4b4fbb2 100644
--- a/Makefile
+++ b/Makefile
@@ -47,7 +47,8 @@ lib/libgensvm.a: \
src/gensvm_strutil.o \
src/gensvm_sv.o \
src/gensvm_task.o \
- src/gensvm_timer.o
+ src/gensvm_timer.o \
+ src/gensvm_train.o
@ar rcs lib/libgensvm.a \
src/gensvm_base.o \
src/gensvm_cmdarg.o \
@@ -67,7 +68,8 @@ lib/libgensvm.a: \
src/gensvm_strutil.o \
src/gensvm_sv.o \
src/gensvm_task.o \
- src/gensvm_timer.o
+ src/gensvm_timer.o \
+ src/gensvm_train.o
@echo libgensvm.a...
gensvm: src/GenSVMtraintest.c lib/libgensvm.a
diff --git a/include/gensvm_optimize.h b/include/gensvm_optimize.h
index 7a23bdb..1041b72 100644
--- a/include/gensvm_optimize.h
+++ b/include/gensvm_optimize.h
@@ -10,8 +10,8 @@
*
*/
-#ifndef GENSVM_TRAIN_H
-#define GENSVM_TRAIN_H
+#ifndef GENSVM_OPTIMIZE_H
+#define GENSVM_OPTIMIZE_H
#include "gensvm_sv.h"
#include "gensvm_print.h"
diff --git a/include/gensvm_train.h b/include/gensvm_train.h
new file mode 100644
index 0000000..c410cef
--- /dev/null
+++ b/include/gensvm_train.h
@@ -0,0 +1,21 @@
+/**
+ * @file gensvm_train.h
+ * @author Gertjan van den Burg
+ * @date May, 2016
+ * @brief Header file for gensvm_train.c
+ *
+ */
+
+#ifndef GENSVM_TRAIN_H
+#define GENSVM_TRAIN_H
+
+// includes
+#include "gensvm_init.h"
+#include "gensvm_kernel.h"
+#include "gensvm_optimize.h"
+
+// function declarations
+void gensvm_train(struct GenModel *model, struct GenData *data,
+ struct GenModel *seed_model);
+
+#endif
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c
index 47cc900..73c0ce9 100644
--- a/src/GenSVMtraintest.c
+++ b/src/GenSVMtraintest.c
@@ -12,8 +12,7 @@
#include "gensvm_cmdarg.h"
#include "gensvm_io.h"
-#include "gensvm_init.h"
-#include "gensvm_optimize.h"
+#include "gensvm_train.h"
#include "gensvm_pred.h"
#define MINARGS 2
@@ -54,10 +53,8 @@ void exit_with_help()
exit(0);
}
-
int main(int argc, char **argv)
{
- bool with_test = false;
long i, *predy = NULL;
double performance;
@@ -68,6 +65,7 @@ int main(int argc, char **argv)
*prediction_outputfile = NULL;
struct GenModel *model = gensvm_init_model();
+ struct GenModel *seed_model = NULL;
struct GenData *traindata = gensvm_init_data();
struct GenData *testdata = gensvm_init_data();
@@ -77,53 +75,29 @@ int main(int argc, char **argv)
parse_command_line(argc, argv, model, &model_inputfile,
&training_inputfile, &testing_inputfile,
&model_outputfile, &prediction_outputfile);
- if (testing_inputfile != NULL)
- with_test = true;
// read data from files
gensvm_read_data(traindata, training_inputfile);
- if (with_test)
- gensvm_read_data(testdata, testing_inputfile);
-
- // copy dataset parameters to model
- model->n = traindata->n;
- model->m = traindata->m;
- model->K = traindata->K;
model->data_file = training_inputfile;
- // allocate model
- gensvm_allocate_model(model);
-
- // run pre/post processing in case of kernels
- gensvm_kernel_preprocess(model, traindata);
- if (with_test)
- gensvm_kernel_postprocess(model, traindata, testdata);
-
- // reallocate model in case of kernel dimension reduction
- gensvm_reallocate_model(model, traindata->n, traindata->r);
-
- // initialize weights
- gensvm_initialize_weights(traindata, model);
-
// seed the random number generator
srand(time(NULL));
// load a seed model from file if it is specified
if (gensvm_check_argv_eq(argc, argv, "-s")) {
- struct GenModel *seed_model = gensvm_init_model();
+ seed_model = gensvm_init_model();
gensvm_read_model(seed_model, model_inputfile);
- gensvm_init_V(seed_model, model, traindata);
- gensvm_free_model(seed_model);
- } else {
- gensvm_init_V(NULL, model, traindata);
}
- // start training
- gensvm_optimize(model, traindata);
+ // train the GenSVM model
+ gensvm_train(model, traindata, seed_model);
// if we also have a test set, predict labels and write to predictions
// to an output file if specified
- if (with_test) {
+ if (testing_inputfile != NULL) {
+ gensvm_read_data(testdata, testing_inputfile);
+ gensvm_kernel_postprocess(model, traindata, testdata);
+
// predict labels
predy = Calloc(long, testdata->n);
gensvm_predict_labels(testdata, model, predy);
@@ -154,6 +128,7 @@ int main(int argc, char **argv)
// free everything
gensvm_free_model(model);
+ gensvm_free_model(seed_model);
gensvm_free_data(traindata);
gensvm_free_data(testdata);
free(training_inputfile);
diff --git a/src/gensvm_base.c b/src/gensvm_base.c
index eddef5c..568f19a 100644
--- a/src/gensvm_base.c
+++ b/src/gensvm_base.c
@@ -194,6 +194,9 @@ void gensvm_reallocate_model(struct GenModel *model, long n, long m)
*/
void gensvm_free_model(struct GenModel *model)
{
+ if (model == NULL)
+ return;
+
free(model->W);
free(model->t);
free(model->V);
diff --git a/src/gensvm_optimize.c b/src/gensvm_optimize.c
index 70b3620..464815d 100644
--- a/src/gensvm_optimize.c
+++ b/src/gensvm_optimize.c
@@ -1,5 +1,5 @@
/**
- * @file gensvm_train.c
+ * @file gensvm_optimize.c
* @author Gertjan van den Burg
* @date August 9, 2013
* @brief Main functions for training the GenSVM solution.
diff --git a/src/gensvm_train.c b/src/gensvm_train.c
new file mode 100644
index 0000000..4c0f332
--- /dev/null
+++ b/src/gensvm_train.c
@@ -0,0 +1,51 @@
+/**
+ * @file gensvm_train.c
+ * @author Gertjan van den Burg
+ * @date May, 2016
+ * @brief Main function for training a GenSVM model.
+ *
+ */
+
+#include "gensvm_train.h"
+
+/**
+ * @brief Utility function for training a GenSVM model
+ *
+ * @details
+ * This function organizes model allocation, kernel preprocessing, instance
+ * weight initialization, V initialization, and model training, and should be
+ * used for training a single GenSVM model. V is initialized only after the
+ * model has been allocated and reallocated, which is the order main() used
+ * before this function existed. Pass NULL when no seed model is used.
+ *
+ * @param[in,out] model a GenModel instance
+ * @param[in] data a GenData instance with the training data
+ * @param[in] seed_model an optional GenModel to seed the V matrix
+ *
+ */
+void gensvm_train(struct GenModel *model, struct GenData *data,
+ struct GenModel *seed_model)
+{
+ // copy dataset parameters to model
+ model->n = data->n;
+ model->m = data->m;
+ model->K = data->K;
+
+ // allocate model
+ gensvm_allocate_model(model);
+
+ // preprocess kernel
+ gensvm_kernel_preprocess(model, data);
+
+ // reallocate model for kernels
+ gensvm_reallocate_model(model, data->n, data->r);
+
+ // initialize weights
+ gensvm_initialize_weights(data, model);
+
+ // initialize the V matrix (potentially with a seed model)
+ gensvm_init_V(seed_model, model, data);
+
+ // start training
+ gensvm_optimize(model, data);
+}