aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2017-02-23 22:03:00 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2017-02-23 22:04:54 -0500
commit7acce74e8eeb6a8fcd99af61030977e0fb498f88 (patch)
treed3e3d73f334ee700041d29e4433760206aa8a82e
parentadd additional gcc checks (diff)
downloadgensvm-7acce74e8eeb6a8fcd99af61030977e0fb498f88.tar.gz
gensvm-7acce74e8eeb6a8fcd99af61030977e0fb498f88.zip
Allow setting of the random seed in the model
-rw-r--r--include/gensvm_base.h2
-rw-r--r--src/GenSVMgrid.c17
-rw-r--r--src/GenSVMtraintest.c7
-rw-r--r--src/gensvm_base.c1
-rw-r--r--src/gensvm_copy.c2
-rw-r--r--src/gensvm_train.c6
-rw-r--r--tests/src/test_gensvm_copy.c2
7 files changed, 30 insertions, 7 deletions
diff --git a/include/gensvm_base.h b/include/gensvm_base.h
index 367a1b2..f5f205f 100644
--- a/include/gensvm_base.h
+++ b/include/gensvm_base.h
@@ -142,6 +142,8 @@ struct GenModel {
///< maximum number of iterations of the algorithm
int status;
///< status of the model after training
+ long seed;
+ ///< seed for the random number generator (-1 = random)
};
/**
diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c
index 3907272..782200f 100644
--- a/src/GenSVMgrid.c
+++ b/src/GenSVMgrid.c
@@ -53,7 +53,7 @@ extern FILE *GENSVM_ERROR_FILE;
// function declarations
void exit_with_help(char **argv);
-void parse_command_line(int argc, char **argv, char *input_filename);
+long parse_command_line(int argc, char **argv, char *input_filename);
void read_grid_from_file(char *input_filename, struct GenGrid *grid);
/**
@@ -77,6 +77,7 @@ void exit_with_help(char **argv)
printf("-h | -help : print this help.\n");
printf("-q : quiet mode (no output, not even errors!)\n");
printf("-x : data files are in LibSVM/SVMlight format\n");
+ printf("-z : seed for the random number generator\n");
exit(EXIT_FAILURE);
}
@@ -101,6 +102,7 @@ void exit_with_help(char **argv)
*/
int main(int argc, char **argv)
{
+ long seed;
bool libsvm_format = false;
char input_filename[GENSVM_MAX_LINE_LENGTH];
@@ -112,7 +114,7 @@ int main(int argc, char **argv)
if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help")
|| gensvm_check_argv_eq(argc, argv, "-h") )
exit_with_help(argv);
- parse_command_line(argc, argv, input_filename);
+ seed = parse_command_line(argc, argv, input_filename);
libsvm_format = gensvm_check_argv(argc, argv, "-x");
note("Reading grid file\n");
@@ -154,7 +156,7 @@ int main(int argc, char **argv)
note("Creating queue\n");
gensvm_fill_queue(grid, q, train_data, test_data);
- srand(time(NULL));
+ srand(seed);
note("Starting training\n");
gensvm_train_queue(q);
@@ -186,10 +188,12 @@ int main(int argc, char **argv)
* @param[in] argv array of command line arguments
* @param[in] input_filename pre-allocated buffer for the grid
* filename.
+ * @returns seed for the RNG
*
*/
-void parse_command_line(int argc, char **argv, char *input_filename)
+long parse_command_line(int argc, char **argv, char *input_filename)
{
+ long seed = time(NULL);
int i;
GENSVM_OUTPUT_FILE = stdout;
@@ -208,6 +212,9 @@ void parse_command_line(int argc, char **argv, char *input_filename)
case 'x':
i--;
break;
+ case 'z':
+ seed = atoi(argv[i]);
+ break;
default:
fprintf(stderr, "Unknown option: -%c\n",
argv[i-1][1]);
@@ -219,6 +226,8 @@ void parse_command_line(int argc, char **argv, char *input_filename)
exit_with_help(argv);
strcpy(input_filename, argv[i]);
+
+ return seed;
}
/**
diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c
index 285be2a..63e6d58 100644
--- a/src/GenSVMtraintest.c
+++ b/src/GenSVMtraintest.c
@@ -95,6 +95,7 @@ void exit_with_help(char **argv)
"3=SIGMOID)\n");
printf("-x : data files are in LibSVM/SVMlight "
"format\n");
+ printf("-z seed : seed for the random number generator\n");
printf("\n");
exit(EXIT_FAILURE);
@@ -165,9 +166,6 @@ int main(int argc, char **argv)
gensvm_free_sparse(traindata->spZ);
}
- // seed the random number generator
- srand(time(NULL));
-
// load a seed model from file if it is specified
if (gensvm_check_argv_eq(argc, argv, "-s")) {
seed_model = gensvm_init_model();
@@ -348,6 +346,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model,
case 'x':
i--;
break;
+ case 'z':
+ model->seed = atoi(argv[i]);
+ break;
default:
// this one should always print explicitly to
// stderr, even if '-q' is supplied, because
diff --git a/src/gensvm_base.c b/src/gensvm_base.c
index 24b9c71..00004d2 100644
--- a/src/gensvm_base.c
+++ b/src/gensvm_base.c
@@ -118,6 +118,7 @@ struct GenModel *gensvm_init_model(void)
model->training_error = -1;
model->elapsed_iter = -1;
model->status = -1;
+ model->seed = -1;
model->V = NULL;
model->Vbar = NULL;
diff --git a/src/gensvm_copy.c b/src/gensvm_copy.c
index 0eaaeed..c2a2b59 100644
--- a/src/gensvm_copy.c
+++ b/src/gensvm_copy.c
@@ -41,6 +41,7 @@
* - GenModel::coef
* - GenModel::degree
* - GenModel::max_iter
+ * - GenModel::seed
*
* @param[in] from GenModel to copy parameters from
* @param[in,out] to GenModel to copy parameters to
@@ -59,4 +60,5 @@ void gensvm_copy_model(struct GenModel *from, struct GenModel *to)
to->degree = from->degree;
to->max_iter = from->max_iter;
+ to->seed = from->seed;
}
diff --git a/src/gensvm_train.c b/src/gensvm_train.c
index bc1dad7..5c668e0 100644
--- a/src/gensvm_train.c
+++ b/src/gensvm_train.c
@@ -44,6 +44,8 @@
void gensvm_train(struct GenModel *model, struct GenData *data,
struct GenModel *seed_model)
{
+ long real_seed;
+
// copy dataset parameters to model
model->n = data->n;
model->m = data->m;
@@ -52,6 +54,10 @@ void gensvm_train(struct GenModel *model, struct GenData *data,
// allocate model
gensvm_allocate_model(model);
+ // set the random seed
+ real_seed = (model->seed == -1) ? time(NULL) : model->seed;
+ srand(real_seed);
+
// initialize the V matrix (potentially with a seed model)
gensvm_init_V(seed_model, model, data);
diff --git a/tests/src/test_gensvm_copy.c b/tests/src/test_gensvm_copy.c
index c9acf23..8414a12 100644
--- a/tests/src/test_gensvm_copy.c
+++ b/tests/src/test_gensvm_copy.c
@@ -39,6 +39,7 @@ char *test_copy_model_linear()
from_model->weight_idx = 2;
from_model->kerneltype = K_LINEAR;
from_model->max_iter = 100;
+ from_model->seed = 123;
gensvm_copy_model(from_model, to_model);
@@ -50,6 +51,7 @@ char *test_copy_model_linear()
"to_model->weight_idx incorrect.");
mu_assert(to_model->kerneltype == K_LINEAR, "to->kerneltype incorrect");
mu_assert(to_model->max_iter == 100, "to->max_iter incorrect");
+ mu_assert(to_model->seed == 123, "to->seed incorrect");
gensvm_free_model(from_model);
gensvm_free_model(to_model);