diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-02-23 17:10:16 +0000 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-02-23 17:10:16 +0000 |
| commit | bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468 (patch) | |
| tree | dcf17ae3e5fef6cee367f0da985bb4894495aee9 /R | |
| parent | update gensvm C library (diff) | |
| download | rgensvm-bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468.tar.gz rgensvm-bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468.zip | |
Implement fitting and prediction
Diffstat (limited to 'R')
| -rw-r--r-- | R/gensvm.R | 46 | ||||
| -rw-r--r-- | R/predict.gensvm.R | 25 | ||||
| -rw-r--r-- | R/print.gensvm.R | 28 | ||||
| -rw-r--r-- | R/util.labelencoder.R | 1 |
4 files changed, 73 insertions, 27 deletions
@@ -56,9 +56,9 @@ #' \code{\link{plot}}, and \code{\link{gensvm.grid}}. #' #' @export +#' @useDynLib gensvm_wrapper, .registration = TRUE #' #' @examples -#' X <- #' gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, weights='unit', kernel='linear', gamma='auto', coef=0.0, @@ -67,10 +67,10 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, { call <- match.call() - - # TODO: generate the random.seed value in R if it is NULL. Then you can - # return it and people can still reproduce even if they forgot to set it - # explicitly. + # Generate the random.seed value in R if it is NULL. This way users can + # reproduce the run because it is returned in the output object. + if (is.null(random.seed)) + random.seed <- runif(1) * (2**31 - 1) # TODO: Store a labelencoder in the object, preferably as a partially # hidden item. This can then be used with prediction. @@ -79,9 +79,13 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, n.features <- ncol(X) n.classes <- length(unique(y)) - # Convert labels to integers - y.clean <- label.encode(y) + classes <- sort(unique(y)) + y.clean <- match(y, classes) + + # Convert gamma if it is 'auto' + if (gamma == 'auto') + gamma <- 1.0/n.features # Convert weights to index weight.idx <- which(c("unit", "group") == weights) @@ -90,39 +94,43 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, "Valid options are 'unit' and 'group'") } - # Convert kernel to index - kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) + # Convert kernel to index (remember off-by-one for R vs. C) + kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 if (length(kernel.idx) == 0) { stop("Incorrect kernel specification. ", "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") } - out <- .Call("R_gensvm_train", - as.matrix(t(X)), + as.matrix(X), as.integer(y.clean), p, lambda, kappa, epsilon, weight.idx, - kernel.idx, + as.integer(kernel.idx), gamma, coef, degree, kernel.eigen.cutoff, - verbose, - max.iter, - random.seed, - seed.V) - + as.integer(verbose), + as.integer(max.iter), + as.integer(random.seed), + seed.V, + as.integer(n.objects), + as.integer(n.features), + as.integer(n.classes)) - object <- list(call = call, lambda = lambda, kappa = kappa, + object <- list(call = call, p = p, lambda = lambda, kappa = kappa, epsilon = epsilon, weights = weights, kernel = kernel, gamma = gamma, coef = coef, degree = degree, kernel.eigen.cutoff = kernel.eigen.cutoff, random.seed = random.seed, max.iter = max.iter, - V = out$V, n.iter = out$n.iter, n.support = out$n.support) + n.objects = n.objects, n.features = n.features, + n.classes = n.classes, classes = classes, V = out$V, + n.iter = out$n.iter, n.support = out$n.support) class(object) <- "gensvm" + return(object) } diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 6cc8851..5c8f2e7 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -3,7 +3,7 @@ #' @description This function predicts the class labels of new data using a #' fitted GenSVM model. #' -#' @param object Fitted \code{gensvm} object +#' @param fit Fitted \code{gensvm} object #' @param newx Matrix of new values for \code{x} for which predictions need to #' be made. #' @param \dots further arguments are ignored @@ -27,10 +27,27 @@ #' #' #' -predict.gensvm <- function(object, newx, ...) +predict.gensvm <- function(fit, newx, ...) { - # TODO: C library fitting prediction here (or not? with the column-major - # order it may be faster to do it directly in R) + ## Implementation note: + ## - It might seem that it would be faster to do the prediction directly in + ## R here, since we then don't need to switch to C, construct model and + ## data structures, copy the data, etc. before doing the prediction. + ## However, if you actually implement it and compare, we find that calling + ## the C implementation is *much* faster than doing it in R. + + # Sanity check + if (ncol(newx) != fit$n.features) + stop("Number of features of fitted model and supplied data disagree.") + + y.pred.c <- .Call("R_gensvm_predict", + as.matrix(newx), + as.matrix(fit$V), + as.integer(nrow(newx)), + as.integer(ncol(newx)), + as.integer(fit$n.classes) + ) + yhat <- fit$classes[y.pred.c] return(yhat) } diff --git a/R/print.gensvm.R b/R/print.gensvm.R index 8d17b0c..06a3649 100644 --- a/R/print.gensvm.R +++ b/R/print.gensvm.R @@ -26,9 +26,31 @@ print.gensvm <- function(object, ...) { cat("\nCall:\n") dput(object$call) + cat("\nData:\n") + cat("\tn.objects:", object$n.objects, "\n") + cat("\tn.features:", object$n.features, "\n") + cat("\tn.classes:", object$n.classes, "\n") + cat("\tclasses:", object$classes, "\n") + cat("Parameters:\n") + cat("\tp:", object$p, "\n") + cat("\tlambda:", object$lambda, "\n") + cat("\tkappa:", object$kappa, "\n") + cat("\tepsilon:", object$epsilon, "\n") + cat("\tweights:", object$weights, "\n") + cat("\tmax.iter:", object$max.iter, "\n") + cat("\trandom.seed:", object$random.seed, "\n") + cat("\tkernel:", object$kernel, "\n") + if (object$kernel %in% c("poly", "rbf", "sigmoid")) { + cat("\tkernel.eigen.cutoff:", object$kernel.eigen.cutoff, "\n") + cat("\tgamma:", object$gamma, "\n") + } + if (object$kernel %in% c("poly", "sigmoid")) + cat("\tcoef:", object$coef, "\n") + if (object$kernel == 'poly') + cat("\tdegree:", object$degree, "\n") + cat("Results:\n") + cat("\tn.iter:", object$n.iter, "\n") + cat("\tn.support:", object$n.support, "\n") - # TODO: fill this out - # - # invisible(object) } diff --git a/R/util.labelencoder.R b/R/util.labelencoder.R deleted file mode 100644 index 8b13789..0000000 --- a/R/util.labelencoder.R +++ /dev/null @@ -1 +0,0 @@ - |
