diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-04-03 23:21:12 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-04-03 23:21:12 +0100 |
| commit | f6c4360090c1235ce54f84ddc8ab6a2629f8d8e4 (patch) | |
| tree | cbffa0c4c2a5f22ad7e1c379274488de5421a7bc | |
| parent | Add fitted for grid class (diff) | |
| download | rgensvm-f6c4360090c1235ce54f84ddc8ab6a2629f8d8e4.tar.gz rgensvm-f6c4360090c1235ce54f84ddc8ab6a2629f8d8e4.zip | |
Add support for raw weight vector in gensvm()
| -rw-r--r-- | R/gensvm.R | 11 | ||||
| -rw-r--r-- | R/validate.R | 2 | ||||
| m--------- | src/gensvm | 0 | ||||
| -rw-r--r-- | src/gensvm_wrapper.c | 21 |
4 files changed, 27 insertions, 7 deletions
@@ -139,13 +139,21 @@ gensvm <- function(x, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6, if (gamma == 'auto') gamma <- 1.0/n.features + raw.weights <- if (is.character(weights)) NULL else weights + weights <- if (is.character(weights)) weights else "raw" + + if (weights == "raw" && length(raw.weights) != n.objects) { + cat("Error: length of weights vector unequal to number of objects\n") + return(invisible(NULL)) + } + if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda, epsilon=epsilon, gamma=gamma, weights=weights, kernel=kernel)) return(invisible(NULL)) # Convert weights to index - weight.idx <- which(c("unit", "group") == weights) + weight.idx <- which(c("raw", "unit", "group") == weights) - 1 # Convert kernel to index (remember off-by-one for R vs. C) kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 @@ -162,6 +170,7 @@ gensvm <- function(x, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6, kappa, epsilon, weight.idx, + raw.weights, as.integer(kernel.idx), gamma, coef, diff --git a/R/validate.R b/R/validate.R index b0f3f39..5960a38 100644 --- a/R/validate.R +++ b/R/validate.R @@ -64,7 +64,7 @@ gensvm.param.conditions <- function() lambda=function(x) {x > 0.0 }, epsilon=function(x) { x > 0.0 }, gamma=function(x) { x != 0.0 }, - weights=function(x) { x %in% c("unit", "group") }, + weights=function(x) { x %in% c("raw", "unit", "group") }, kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } ) } diff --git a/src/gensvm b/src/gensvm -Subproject d80d1570edd794c66be69f6d359c4947b7d7839 +Subproject 1fd9dad0ed6d5299c8d75de6213ed40b183c583 diff --git a/src/gensvm_wrapper.c b/src/gensvm_wrapper.c index 46578d0..275f797 100644 --- a/src/gensvm_wrapper.c +++ b/src/gensvm_wrapper.c @@ -37,10 +37,11 @@ SEXP R_gensvm_train( SEXP R_X, SEXP R_y, SEXP R_p, SEXP R_lambda, SEXP R_kappa, SEXP R_epsilon, SEXP R_weight_idx, - SEXP R_kernel_idx, SEXP R_gamma, SEXP R_coef, SEXP R_degree, - SEXP R_kernel_eigen_cutoff, SEXP R_verbose, SEXP R_max_iter, - SEXP R_random_seed, SEXP R_seed_V, SEXP R_seed_rows, - SEXP R_seed_cols, SEXP R_n, SEXP R_m, SEXP R_K); + SEXP R_raw_weights, SEXP R_kernel_idx, SEXP R_gamma, + SEXP R_coef, SEXP R_degree, SEXP R_kernel_eigen_cutoff, + SEXP R_verbose, SEXP R_max_iter, SEXP R_random_seed, + SEXP R_seed_V, SEXP R_seed_rows, SEXP R_seed_cols, SEXP R_n, + SEXP R_m, SEXP R_K); SEXP R_gensvm_predict(SEXP R_Xtest, SEXP R_V, SEXP R_n, SEXP R_m, SEXP R_K); SEXP R_gensvm_predict_kernels( SEXP R_Xtest, SEXP R_Xtrain, SEXP R_V, SEXP R_V_row, @@ -63,7 +64,7 @@ struct GenData *_build_gensvm_data(double *X, int *y, int n, int m, int K); // Start R package stuff R_CallMethodDef callMethods[] = { - {"R_gensvm_train", (DL_FUNC) &R_gensvm_train, 21}, + {"R_gensvm_train", (DL_FUNC) &R_gensvm_train, 22}, {"R_gensvm_predict", (DL_FUNC) &R_gensvm_predict, 5}, {"R_gensvm_predict_kernels", (DL_FUNC) &R_gensvm_predict_kernels, 14}, {"R_gensvm_plotdata_kernels", (DL_FUNC) &R_gensvm_plotdata_kernels, 14}, @@ -171,6 +172,7 @@ SEXP R_gensvm_train( SEXP R_kappa, SEXP R_epsilon, SEXP R_weight_idx, + SEXP R_raw_weights, SEXP R_kernel_idx, SEXP R_gamma, SEXP R_coef, @@ -194,6 +196,7 @@ SEXP R_gensvm_train( double kappa = *REAL(R_kappa); double epsilon = *REAL(R_epsilon); int weight_idx = *INTEGER(R_weight_idx); + double *raw_weights = isNull(R_raw_weights) ? NULL : REAL(R_raw_weights); int kernel_idx = *INTEGER(R_kernel_idx); double gamma = *REAL(R_gamma); double coef = *REAL(R_coef); @@ -217,6 +220,9 @@ SEXP R_gensvm_train( double value; // Set model parameters from function input arguments + model->n = n; + model->m = m; + model->K = K; model->p = p; model->lambda = lambda; model->kappa = kappa; @@ -230,6 +236,11 @@ SEXP R_gensvm_train( model->max_iter = max_iter; model->seed = random_seed; + if (raw_weights != NULL) { + model->rho = Calloc(double, n); + for (i=0; i<n; i++) model->rho[i] = raw_weights[i]; + } + if (seed_V != NULL) { seed_model = gensvm_init_model(); |
