aboutsummaryrefslogtreecommitdiff
path: root/R/gensvm.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/gensvm.R')
-rw-r--r--R/gensvm.R46
1 files changed, 27 insertions, 19 deletions
diff --git a/R/gensvm.R b/R/gensvm.R
index 1923f06..4a4ab6b 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -56,9 +56,9 @@
#' \code{\link{plot}}, and \code{\link{gensvm.grid}}.
#'
#' @export
+#' @useDynLib gensvm_wrapper, .registration = TRUE
#'
#' @examples
-#' X <-
#'
gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
weights='unit', kernel='linear', gamma='auto', coef=0.0,
@@ -67,10 +67,10 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
{
call <- match.call()
-
- # TODO: generate the random.seed value in R if it is NULL. Then you can
- # return it and people can still reproduce even if they forgot to set it
- # explicitly.
+ # Generate the random.seed value in R if it is NULL. This way users can
+ # reproduce the run because it is returned in the output object.
+ if (is.null(random.seed))
+ random.seed <- runif(1) * (2**31 - 1)
# TODO: Store a labelencoder in the object, preferably as a partially
# hidden item. This can then be used with prediction.
@@ -79,9 +79,13 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
n.features <- ncol(X)
n.classes <- length(unique(y))
-
# Convert labels to integers
- y.clean <- label.encode(y)
+ classes <- sort(unique(y))
+ y.clean <- match(y, classes)
+
+ # Convert gamma if it is 'auto'
+ if (gamma == 'auto')
+ gamma <- 1.0/n.features
# Convert weights to index
weight.idx <- which(c("unit", "group") == weights)
@@ -90,39 +94,43 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
"Valid options are 'unit' and 'group'")
}
- # Convert kernel to index
- kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel)
+ # Convert kernel to index (remember off-by-one for R vs. C)
+ kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
if (length(kernel.idx) == 0) {
stop("Incorrect kernel specification. ",
"Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'")
}
-
out <- .Call("R_gensvm_train",
- as.matrix(t(X)),
+ as.matrix(X),
as.integer(y.clean),
p,
lambda,
kappa,
epsilon,
weight.idx,
- kernel.idx,
+ as.integer(kernel.idx),
gamma,
coef,
degree,
kernel.eigen.cutoff,
- verbose,
- max.iter,
- random.seed,
- seed.V)
-
+ as.integer(verbose),
+ as.integer(max.iter),
+ as.integer(random.seed),
+ seed.V,
+ as.integer(n.objects),
+ as.integer(n.features),
+ as.integer(n.classes))
- object <- list(call = call, lambda = lambda, kappa = kappa,
+ object <- list(call = call, p = p, lambda = lambda, kappa = kappa,
epsilon = epsilon, weights = weights, kernel = kernel,
gamma = gamma, coef = coef, degree = degree,
kernel.eigen.cutoff = kernel.eigen.cutoff,
random.seed = random.seed, max.iter = max.iter,
- V = out$V, n.iter = out$n.iter, n.support = out$n.support)
+ n.objects = n.objects, n.features = n.features,
+ n.classes = n.classes, classes = classes, V = out$V,
+ n.iter = out$n.iter, n.support = out$n.support)
class(object) <- "gensvm"
+
return(object)
}