aboutsummaryrefslogtreecommitdiff
path: root/src/GenSVMtraintest.c
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2016-05-16 21:41:27 +0200
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2016-05-16 21:41:27 +0200
commit2584591266e434140a94d818402745f34fb3acf3 (patch)
tree2faeb4f27b6b4101149e608657d2d018249e73c1 /src/GenSVMtraintest.c
parentmajor refactor of the code (diff)
downloadgensvm-2584591266e434140a94d818402745f34fb3acf3.tar.gz
gensvm-2584591266e434140a94d818402745f34fb3acf3.zip
create a single training function for easy external access
Diffstat (limited to 'src/GenSVMtraintest.c')
-rw-r--r--src/GenSVMtraintest.c45
1 files changed, 10 insertions, 35 deletions
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c
index 47cc900..73c0ce9 100644
--- a/src/GenSVMtraintest.c
+++ b/src/GenSVMtraintest.c
@@ -12,8 +12,7 @@
#include "gensvm_cmdarg.h"
#include "gensvm_io.h"
-#include "gensvm_init.h"
-#include "gensvm_optimize.h"
+#include "gensvm_train.h"
#include "gensvm_pred.h"
#define MINARGS 2
@@ -54,10 +53,8 @@ void exit_with_help()
exit(0);
}
-
int main(int argc, char **argv)
{
- bool with_test = false;
long i, *predy = NULL;
double performance;
@@ -68,6 +65,7 @@ int main(int argc, char **argv)
*prediction_outputfile = NULL;
struct GenModel *model = gensvm_init_model();
+ struct GenModel *seed_model = NULL;
struct GenData *traindata = gensvm_init_data();
struct GenData *testdata = gensvm_init_data();
@@ -77,53 +75,29 @@ int main(int argc, char **argv)
parse_command_line(argc, argv, model, &model_inputfile,
&training_inputfile, &testing_inputfile,
&model_outputfile, &prediction_outputfile);
- if (testing_inputfile != NULL)
- with_test = true;
// read data from files
gensvm_read_data(traindata, training_inputfile);
- if (with_test)
- gensvm_read_data(testdata, testing_inputfile);
-
- // copy dataset parameters to model
- model->n = traindata->n;
- model->m = traindata->m;
- model->K = traindata->K;
model->data_file = training_inputfile;
- // allocate model
- gensvm_allocate_model(model);
-
- // run pre/post processing in case of kernels
- gensvm_kernel_preprocess(model, traindata);
- if (with_test)
- gensvm_kernel_postprocess(model, traindata, testdata);
-
- // reallocate model in case of kernel dimension reduction
- gensvm_reallocate_model(model, traindata->n, traindata->r);
-
- // initialize weights
- gensvm_initialize_weights(traindata, model);
-
// seed the random number generator
srand(time(NULL));
// load a seed model from file if it is specified
if (gensvm_check_argv_eq(argc, argv, "-s")) {
- struct GenModel *seed_model = gensvm_init_model();
+ seed_model = gensvm_init_model();
gensvm_read_model(seed_model, model_inputfile);
- gensvm_init_V(seed_model, model, traindata);
- gensvm_free_model(seed_model);
- } else {
- gensvm_init_V(NULL, model, traindata);
}
- // start training
- gensvm_optimize(model, traindata);
+ // train the GenSVM model
+ gensvm_train(model, traindata, seed_model);
// if we also have a test set, predict labels and write to predictions
// to an output file if specified
- if (with_test) {
+ if (testing_inputfile != NULL) {
+ gensvm_read_data(testdata, testing_inputfile);
+ gensvm_kernel_postprocess(model, traindata, testdata);
+
// predict labels
predy = Calloc(long, testdata->n);
gensvm_predict_labels(testdata, model, predy);
@@ -154,6 +128,7 @@ int main(int argc, char **argv)
// free everything
gensvm_free_model(model);
+ gensvm_free_model(seed_model);
gensvm_free_data(traindata);
gensvm_free_data(testdata);
free(training_inputfile);