diff options
Diffstat (limited to 'R/gensvm.R')
| -rw-r--r-- | R/gensvm.R | 46 |
1 files changed, 27 insertions, 19 deletions
@@ -56,9 +56,9 @@ #' \code{\link{plot}}, and \code{\link{gensvm.grid}}. #' #' @export +#' @useDynLib gensvm_wrapper, .registration = TRUE #' #' @examples -#' X <- #' gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, weights='unit', kernel='linear', gamma='auto', coef=0.0, @@ -67,10 +67,10 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, { call <- match.call() - - # TODO: generate the random.seed value in R if it is NULL. Then you can - # return it and people can still reproduce even if they forgot to set it - # explicitly. + # Generate the random.seed value in R if it is NULL. This way users can + # reproduce the run because it is returned in the output object. + if (is.null(random.seed)) + random.seed <- runif(1) * (2**31 - 1) # TODO: Store a labelencoder in the object, preferably as a partially # hidden item. This can then be used with prediction. @@ -79,9 +79,13 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, n.features <- ncol(X) n.classes <- length(unique(y)) - # Convert labels to integers - y.clean <- label.encode(y) + classes <- sort(unique(y)) + y.clean <- match(y, classes) + + # Convert gamma if it is 'auto' + if (gamma == 'auto') + gamma <- 1.0/n.features # Convert weights to index weight.idx <- which(c("unit", "group") == weights) @@ -90,39 +94,43 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, "Valid options are 'unit' and 'group'") } - # Convert kernel to index - kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) + # Convert kernel to index (remember off-by-one for R vs. C) + kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 if (length(kernel.idx) == 0) { stop("Incorrect kernel specification. ", "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") } - out <- .Call("R_gensvm_train", - as.matrix(t(X)), + as.matrix(X), as.integer(y.clean), p, lambda, kappa, epsilon, weight.idx, - kernel.idx, + as.integer(kernel.idx), gamma, coef, degree, kernel.eigen.cutoff, - verbose, - max.iter, - random.seed, - seed.V) - + as.integer(verbose), + as.integer(max.iter), + as.integer(random.seed), + seed.V, + as.integer(n.objects), + as.integer(n.features), + as.integer(n.classes)) - object <- list(call = call, lambda = lambda, kappa = kappa, + object <- list(call = call, p = p, lambda = lambda, kappa = kappa, epsilon = epsilon, weights = weights, kernel = kernel, gamma = gamma, coef = coef, degree = degree, kernel.eigen.cutoff = kernel.eigen.cutoff, random.seed = random.seed, max.iter = max.iter, - V = out$V, n.iter = out$n.iter, n.support = out$n.support) + n.objects = n.objects, n.features = n.features, + n.classes = n.classes, classes = classes, V = out$V, + n.iter = out$n.iter, n.support = out$n.support) class(object) <- "gensvm" + return(object) } |
