aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-04-03 23:21:12 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-04-03 23:21:12 +0100
commitf6c4360090c1235ce54f84ddc8ab6a2629f8d8e4 (patch)
treecbffa0c4c2a5f22ad7e1c379274488de5421a7bc
parentAdd fitted for grid class (diff)
downloadrgensvm-f6c4360090c1235ce54f84ddc8ab6a2629f8d8e4.tar.gz
rgensvm-f6c4360090c1235ce54f84ddc8ab6a2629f8d8e4.zip
Add support for raw weight vector in gensvm()
-rw-r--r--R/gensvm.R11
-rw-r--r--R/validate.R2
m---------src/gensvm0
-rw-r--r--src/gensvm_wrapper.c21
4 files changed, 27 insertions, 7 deletions
diff --git a/R/gensvm.R b/R/gensvm.R
index 8994d3a..5652a4c 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -139,13 +139,21 @@ gensvm <- function(x, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
if (gamma == 'auto')
gamma <- 1.0/n.features
+ raw.weights <- if (is.character(weights)) NULL else weights
+ weights <- if (is.character(weights)) weights else "raw"
+
+ if (weights == "raw" && length(raw.weights) != n.objects) {
+ cat("Error: length of weights vector unequal to number of objects\n")
+ return(invisible(NULL))
+ }
+
if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda,
epsilon=epsilon, gamma=gamma, weights=weights,
kernel=kernel))
return(invisible(NULL))
# Convert weights to index
- weight.idx <- which(c("unit", "group") == weights)
+ weight.idx <- which(c("raw", "unit", "group") == weights) - 1
# Convert kernel to index (remember off-by-one for R vs. C)
kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
@@ -162,6 +170,7 @@ gensvm <- function(x, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
kappa,
epsilon,
weight.idx,
+ raw.weights,
as.integer(kernel.idx),
gamma,
coef,
diff --git a/R/validate.R b/R/validate.R
index b0f3f39..5960a38 100644
--- a/R/validate.R
+++ b/R/validate.R
@@ -64,7 +64,7 @@ gensvm.param.conditions <- function()
lambda=function(x) {x > 0.0 },
epsilon=function(x) { x > 0.0 },
gamma=function(x) { x != 0.0 },
- weights=function(x) { x %in% c("unit", "group") },
+ weights=function(x) { x %in% c("raw", "unit", "group") },
kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") }
)
}
diff --git a/src/gensvm b/src/gensvm
-Subproject d80d1570edd794c66be69f6d359c4947b7d7839
+Subproject 1fd9dad0ed6d5299c8d75de6213ed40b183c583
diff --git a/src/gensvm_wrapper.c b/src/gensvm_wrapper.c
index 46578d0..275f797 100644
--- a/src/gensvm_wrapper.c
+++ b/src/gensvm_wrapper.c
@@ -37,10 +37,11 @@
SEXP R_gensvm_train( SEXP R_X, SEXP R_y, SEXP R_p, SEXP R_lambda,
SEXP R_kappa, SEXP R_epsilon, SEXP R_weight_idx,
- SEXP R_kernel_idx, SEXP R_gamma, SEXP R_coef, SEXP R_degree,
- SEXP R_kernel_eigen_cutoff, SEXP R_verbose, SEXP R_max_iter,
- SEXP R_random_seed, SEXP R_seed_V, SEXP R_seed_rows,
- SEXP R_seed_cols, SEXP R_n, SEXP R_m, SEXP R_K);
+ SEXP R_raw_weights, SEXP R_kernel_idx, SEXP R_gamma,
+ SEXP R_coef, SEXP R_degree, SEXP R_kernel_eigen_cutoff,
+ SEXP R_verbose, SEXP R_max_iter, SEXP R_random_seed,
+ SEXP R_seed_V, SEXP R_seed_rows, SEXP R_seed_cols, SEXP R_n,
+ SEXP R_m, SEXP R_K);
SEXP R_gensvm_predict(SEXP R_Xtest, SEXP R_V, SEXP R_n, SEXP R_m, SEXP R_K);
SEXP R_gensvm_predict_kernels(
SEXP R_Xtest, SEXP R_Xtrain, SEXP R_V, SEXP R_V_row,
@@ -63,7 +64,7 @@ struct GenData *_build_gensvm_data(double *X, int *y, int n, int m, int K);
// Start R package stuff
R_CallMethodDef callMethods[] = {
- {"R_gensvm_train", (DL_FUNC) &R_gensvm_train, 21},
+ {"R_gensvm_train", (DL_FUNC) &R_gensvm_train, 22},
{"R_gensvm_predict", (DL_FUNC) &R_gensvm_predict, 5},
{"R_gensvm_predict_kernels", (DL_FUNC) &R_gensvm_predict_kernels, 14},
{"R_gensvm_plotdata_kernels", (DL_FUNC) &R_gensvm_plotdata_kernels, 14},
@@ -171,6 +172,7 @@ SEXP R_gensvm_train(
SEXP R_kappa,
SEXP R_epsilon,
SEXP R_weight_idx,
+ SEXP R_raw_weights,
SEXP R_kernel_idx,
SEXP R_gamma,
SEXP R_coef,
@@ -194,6 +196,7 @@ SEXP R_gensvm_train(
double kappa = *REAL(R_kappa);
double epsilon = *REAL(R_epsilon);
int weight_idx = *INTEGER(R_weight_idx);
+ double *raw_weights = isNull(R_raw_weights) ? NULL : REAL(R_raw_weights);
int kernel_idx = *INTEGER(R_kernel_idx);
double gamma = *REAL(R_gamma);
double coef = *REAL(R_coef);
@@ -217,6 +220,9 @@ SEXP R_gensvm_train(
double value;
// Set model parameters from function input arguments
+ model->n = n;
+ model->m = m;
+ model->K = K;
model->p = p;
model->lambda = lambda;
model->kappa = kappa;
@@ -230,6 +236,11 @@ SEXP R_gensvm_train(
model->max_iter = max_iter;
model->seed = random_seed;
+ if (raw_weights != NULL) {
+ model->rho = Calloc(double, n);
+ for (i=0; i<n; i++) model->rho[i] = raw_weights[i];
+ }
+
if (seed_V != NULL) {
seed_model = gensvm_init_model();