diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2015-01-30 16:22:52 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2015-01-30 16:22:52 +0100 |
| commit | df9c3ca0b62f1a20071bee3a55d24d673c5d11e0 (patch) | |
| tree | d3a2d6be5dfe6e2a4e248ad04dfdbb40852c8f7a /include | |
| parent | update documentation gensvm structs (diff) | |
| download | gensvm-df9c3ca0b62f1a20071bee3a55d24d673c5d11e0.tar.gz gensvm-df9c3ca0b62f1a20071bee3a55d24d673c5d11e0.zip | |
first working version of new kernel GenSVM
Diffstat (limited to 'include')
| -rw-r--r-- | include/gensvm.h | 6 | ||||
| -rw-r--r-- | include/gensvm_kernel.h | 24 | ||||
| -rw-r--r-- | include/gensvm_pred.h | 8 | ||||
| -rw-r--r-- | include/gensvm_train_dataset.h | 8 |
4 files changed, 27 insertions, 19 deletions
diff --git a/include/gensvm.h b/include/gensvm.h index ddae3ae..5101b41 100644 --- a/include/gensvm.h +++ b/include/gensvm.h @@ -88,7 +88,9 @@ struct GenData { long n; ///< number of instances long m; - ///< number of predictors + ///< number of predictors (width of RAW) + long r; + ///< number of eigenvalues (width of Z) long *y; ///< array of class labels, 1..K double *Z; @@ -96,7 +98,7 @@ struct GenData { ///< of the kernel matrix) double *RAW; ///< augmented raw data matrix - double *J; + double *Sigma; KernelType kerneltype; double *kernelparam; }; diff --git a/include/gensvm_kernel.h b/include/gensvm_kernel.h index bf46bbc..d5c5e8d 100644 --- a/include/gensvm_kernel.h +++ b/include/gensvm_kernel.h @@ -21,18 +21,22 @@ struct GenData; struct GenModel; // function declarations -void gensvm_make_kernel(struct GenModel *model, struct GenData *data); - -long gensvm_make_eigen(double *K, long n, double **P, double **Lambda); +void gensvm_kernel_preprocess(struct GenModel *model, struct GenData *data); +void gensvm_kernel_postprocess(struct GenModel *model, + struct GenData *traindata, struct GenData *testdata); +void gensvm_make_kernel(struct GenModel *model, struct GenData *data, + double *K); +long gensvm_make_eigen(double *K, long n, double **P, double **Sigma); void gensvm_make_crosskernel(struct GenModel *model, - struct GenData *data_train, struct GenData *data_test, + struct GenData *data_train, struct GenData *data_test, double **K2); +void gensvm_make_trainfactor(struct GenData *data, double *P, double *Sigma, + long r); +void gensvm_make_testfactor(struct GenData *testdata, + struct GenData *traindata, double *K2); +double gensvm_dot_rbf(double *x1, double *x2, double *kernelparam, long n); +double gensvm_dot_poly(double *x1, double *x2, double *kernelparam, long n); +double gensvm_dot_sigmoid(double *x1, double *x2, double *kernelparam, long n); -double gensvm_compute_rbf(double *x1, double *x2, double *kernelparam, - long n); -double gensvm_compute_poly(double *x1, double *x2, double *kernelparam, - long n); -double gensvm_compute_sigmoid(double *x1, double *x2, double *kernelparam, - long n); #endif diff --git a/include/gensvm_pred.h b/include/gensvm_pred.h index 0cce20b..76b3ad3 100644 --- a/include/gensvm_pred.h +++ b/include/gensvm_pred.h @@ -19,14 +19,8 @@ struct GenData; struct GenModel; // function declarations -void gensvm_predict_labels(struct GenData *data_test, - struct GenData *data_train, struct GenModel *model, - long *predy); -void gensvm_predict_labels_linear(struct GenData *data, +void gensvm_predict_labels(struct GenData *testdata, struct GenModel *model, long *predy); -void gensvm_predict_labels_kernel(struct GenData *data_test, - struct GenData *data_train, struct GenModel *model, - long *predy); double gensvm_prediction_perf(struct GenData *data, long *perdy); #endif diff --git a/include/gensvm_train_dataset.h b/include/gensvm_train_dataset.h index 299bc52..0dc4319 100644 --- a/include/gensvm_train_dataset.h +++ b/include/gensvm_train_dataset.h @@ -136,4 +136,12 @@ void make_model_from_task(struct Task *task, struct GenModel *model); void copy_model(struct GenModel *from, struct GenModel *to); void print_progress_string(struct Task *task, long N); + +// new +void start_training(struct Queue *q); +double gensvm_cross_validation(struct GenModel *model, + struct GenData **train_folds, struct GenData **test_folds, + int folds, long n_total); + + #endif |
