diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2017-02-23 22:03:00 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2017-02-23 22:04:54 -0500 |
| commit | 7acce74e8eeb6a8fcd99af61030977e0fb498f88 (patch) | |
| tree | d3e3d73f334ee700041d29e4433760206aa8a82e | |
| parent | add additional gcc checks (diff) | |
| download | gensvm-7acce74e8eeb6a8fcd99af61030977e0fb498f88.tar.gz gensvm-7acce74e8eeb6a8fcd99af61030977e0fb498f88.zip | |
Allow setting of the random seed in the model
| -rw-r--r-- | include/gensvm_base.h | 2 | ||||
| -rw-r--r-- | src/GenSVMgrid.c | 17 | ||||
| -rw-r--r-- | src/GenSVMtraintest.c | 7 | ||||
| -rw-r--r-- | src/gensvm_base.c | 1 | ||||
| -rw-r--r-- | src/gensvm_copy.c | 2 | ||||
| -rw-r--r-- | src/gensvm_train.c | 6 | ||||
| -rw-r--r-- | tests/src/test_gensvm_copy.c | 2 |
7 files changed, 30 insertions, 7 deletions
diff --git a/include/gensvm_base.h b/include/gensvm_base.h index 367a1b2..f5f205f 100644 --- a/include/gensvm_base.h +++ b/include/gensvm_base.h @@ -142,6 +142,8 @@ struct GenModel { ///< maximum number of iterations of the algorithm int status; ///< status of the model after training + long seed; + ///< seed for the random number generator (-1 = random) }; /** diff --git a/src/GenSVMgrid.c b/src/GenSVMgrid.c index 3907272..782200f 100644 --- a/src/GenSVMgrid.c +++ b/src/GenSVMgrid.c @@ -53,7 +53,7 @@ extern FILE *GENSVM_ERROR_FILE; // function declarations void exit_with_help(char **argv); -void parse_command_line(int argc, char **argv, char *input_filename); +long parse_command_line(int argc, char **argv, char *input_filename); void read_grid_from_file(char *input_filename, struct GenGrid *grid); /** @@ -77,6 +77,7 @@ void exit_with_help(char **argv) printf("-h | -help : print this help.\n"); printf("-q : quiet mode (no output, not even errors!)\n"); printf("-x : data files are in LibSVM/SVMlight format\n"); + printf("-z : seed for the random number generator\n"); exit(EXIT_FAILURE); } @@ -101,6 +102,7 @@ void exit_with_help(char **argv) */ int main(int argc, char **argv) { + long seed; bool libsvm_format = false; char input_filename[GENSVM_MAX_LINE_LENGTH]; @@ -112,7 +114,7 @@ int main(int argc, char **argv) if (argc < MINARGS || gensvm_check_argv(argc, argv, "-help") || gensvm_check_argv_eq(argc, argv, "-h") ) exit_with_help(argv); - parse_command_line(argc, argv, input_filename); + seed = parse_command_line(argc, argv, input_filename); libsvm_format = gensvm_check_argv(argc, argv, "-x"); note("Reading grid file\n"); @@ -154,7 +156,7 @@ int main(int argc, char **argv) note("Creating queue\n"); gensvm_fill_queue(grid, q, train_data, test_data); - srand(time(NULL)); + srand(seed); note("Starting training\n"); gensvm_train_queue(q); @@ -186,10 +188,12 @@ int main(int argc, char **argv) * @param[in] argv array of command line arguments * @param[in] input_filename pre-allocated buffer for the grid * filename. + * @returns seed for the RNG * */ -void parse_command_line(int argc, char **argv, char *input_filename) +long parse_command_line(int argc, char **argv, char *input_filename) { + long seed = time(NULL); int i; GENSVM_OUTPUT_FILE = stdout; @@ -208,6 +212,9 @@ void parse_command_line(int argc, char **argv, char *input_filename) case 'x': i--; break; + case 'z': + seed = atoi(argv[i]); + break; default: fprintf(stderr, "Unknown option: -%c\n", argv[i-1][1]); @@ -219,6 +226,8 @@ void parse_command_line(int argc, char **argv, char *input_filename) exit_with_help(argv); strcpy(input_filename, argv[i]); + + return seed; } /** diff --git a/src/GenSVMtraintest.c b/src/GenSVMtraintest.c index 285be2a..63e6d58 100644 --- a/src/GenSVMtraintest.c +++ b/src/GenSVMtraintest.c @@ -95,6 +95,7 @@ void exit_with_help(char **argv) "3=SIGMOID)\n"); printf("-x : data files are in LibSVM/SVMlight " "format\n"); + printf("-z seed : seed for the random number generator\n"); printf("\n"); exit(EXIT_FAILURE); @@ -165,9 +166,6 @@ int main(int argc, char **argv) gensvm_free_sparse(traindata->spZ); } - // seed the random number generator - srand(time(NULL)); - // load a seed model from file if it is specified if (gensvm_check_argv_eq(argc, argv, "-s")) { seed_model = gensvm_init_model(); @@ -348,6 +346,9 @@ void parse_command_line(int argc, char **argv, struct GenModel *model, case 'x': i--; break; + case 'z': + model->seed = atoi(argv[i]); + break; default: // this one should always print explicitly to // stderr, even if '-q' is supplied, because diff --git a/src/gensvm_base.c b/src/gensvm_base.c index 24b9c71..00004d2 100644 --- a/src/gensvm_base.c +++ b/src/gensvm_base.c @@ -118,6 +118,7 @@ struct GenModel *gensvm_init_model(void) model->training_error = -1; model->elapsed_iter = -1; model->status = -1; + model->seed = -1; model->V = NULL; model->Vbar = NULL; diff --git a/src/gensvm_copy.c b/src/gensvm_copy.c index 0eaaeed..c2a2b59 100644 --- a/src/gensvm_copy.c +++ b/src/gensvm_copy.c @@ -41,6 +41,7 @@ * - GenModel::coef * - GenModel::degree * - GenModel::max_iter + * - GenModel::seed * * @param[in] from GenModel to copy parameters from * @param[in,out] to GenModel to copy parameters to @@ -59,4 +60,5 @@ void gensvm_copy_model(struct GenModel *from, struct GenModel *to) to->degree = from->degree; to->max_iter = from->max_iter; + to->seed = from->seed; } diff --git a/src/gensvm_train.c b/src/gensvm_train.c index bc1dad7..5c668e0 100644 --- a/src/gensvm_train.c +++ b/src/gensvm_train.c @@ -44,6 +44,8 @@ void gensvm_train(struct GenModel *model, struct GenData *data, struct GenModel *seed_model) { + long real_seed; + // copy dataset parameters to model model->n = data->n; model->m = data->m; @@ -52,6 +54,10 @@ void gensvm_train(struct GenModel *model, struct GenData *data, // allocate model gensvm_allocate_model(model); + // set the random seed + real_seed = (model->seed == -1) ? time(NULL) : model->seed; + srand(real_seed); + // initialize the V matrix (potentially with a seed model) gensvm_init_V(seed_model, model, data); diff --git a/tests/src/test_gensvm_copy.c b/tests/src/test_gensvm_copy.c index c9acf23..8414a12 100644 --- a/tests/src/test_gensvm_copy.c +++ b/tests/src/test_gensvm_copy.c @@ -39,6 +39,7 @@ char *test_copy_model_linear() from_model->weight_idx = 2; from_model->kerneltype = K_LINEAR; from_model->max_iter = 100; + from_model->seed = 123; gensvm_copy_model(from_model, to_model); @@ -50,6 +51,7 @@ char *test_copy_model_linear() "to_model->weight_idx incorrect."); mu_assert(to_model->kerneltype == K_LINEAR, "to->kerneltype incorrect"); mu_assert(to_model->max_iter == 100, "to->max_iter incorrect"); + mu_assert(to_model->seed == 123, "to->seed incorrect"); gensvm_free_model(from_model); gensvm_free_model(to_model); |
