Diffstat (limited to 'R/gensvm.R')
 -rw-r--r--   R/gensvm.R | 92
 1 file changed, 75 insertions(+), 17 deletions(-)
@@ -1,7 +1,8 @@
 #' @title Fit the GenSVM model
 #'
 #' @description Fits the Generalized Multiclass Support Vector Machine model
-#' with the given parameters.
+#' with the given parameters. See the package documentation
+#' (\code{\link{gensvm-package}}) for more general information about GenSVM.
 #'
 #' @param X data matrix with the predictors
 #' @param y class labels
@@ -12,7 +13,8 @@
 #' @param weights type of instance weights to use. Options are 'unit' for unit
 #' weights and 'group' for group size correction weights (eq. 4 in the paper).
 #' @param kernel the kernel type to use in the classifier. It must be one of
-#' 'linear', 'poly', 'rbf', or 'sigmoid'.
+#' 'linear', 'poly', 'rbf', or 'sigmoid'. See the section "Kernels in GenSVM"
+#' in \code{\link{gensvm-package}} for more information.
 #' @param gamma kernel parameter for the rbf, polynomial, and sigmoid kernel.
 #' If gamma is 'auto', then 1/n_features will be used.
 #' @param coef parameter for the polynomial and sigmoid kernel.
@@ -25,25 +27,45 @@
 #' @param random.seed Seed for the random number generator (useful for
 #' reproducible output)
 #' @param max.iter Maximum number of iterations of the optimization algorithm.
+#' @param seed.V Matrix to warm-start the optimization algorithm. This is
+#' typically the output of \code{coef(fit)}. Note that this function will
+#' silently drop seed.V if its dimensions don't match the provided data.
 #'
 #' @return A "gensvm" S3 object is returned for which the print, predict, coef,
 #' and plot methods are available. It has the following items:
 #' \item{call}{The call that was used to construct the model.}
+#' \item{p}{The value of the lp norm in the loss function}
 #' \item{lambda}{The regularization parameter used in the model.}
 #' \item{kappa}{The hinge function parameter used.}
 #' \item{epsilon}{The stopping criterion used.}
 #' \item{weights}{The instance weights type used.}
 #' \item{kernel}{The kernel function used.}
-#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}.
+#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}
 #' \item{coef}{The value of the coef parameter of the kernel, if applicable}
 #' \item{degree}{The degree of the kernel, if applicable}
 #' \item{kernel.eigen.cutoff}{The cutoff value of the reduced
-#' eigendecomposition of the kernel matrix}
+#' eigendecomposition of the kernel matrix.}
+#' \item{verbose}{Whether or not the model was fitted with progress output}
 #' \item{random.seed}{The random seed used to seed the model.}
 #' \item{max.iter}{Maximum number of iterations of the algorithm.}
+#' \item{n.objects}{Number of objects in the dataset}
+#' \item{n.features}{Number of features in the dataset}
+#' \item{n.classes}{Number of classes in the dataset}
+#' \item{classes}{Array with the actual class labels}
+#' \item{V}{Coefficient matrix}
+#' \item{n.iter}{Number of iterations performed in training}
+#' \item{n.support}{Number of support vectors in the final model}
+#' \item{training.time}{Total training time}
+#' \item{X.train}{When training with nonlinear kernels, the training data is
+#' needed to perform prediction. For these kernels it is therefore stored in
+#' the fitted model.}
+#'
+#' @note
+#' This function returns partial results when the computation is interrupted by
+#' the user.
 #'
 #' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
 #'
 #' @references
@@ -59,10 +81,37 @@
 #' @useDynLib gensvm_wrapper, .registration = TRUE
 #'
 #' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # fit using the default parameters
+#' fit <- gensvm(x, y)
+#'
+#' # fit and show progress
+#' fit <- gensvm(x, y, verbose=TRUE)
+#'
+#' # fit with some changed parameters
+#' fit <- gensvm(x, y, lambda=1e-8)
+#'
+#' # Early stopping defined through epsilon
+#' fit <- gensvm(x, y, epsilon=1e-3)
+#'
+#' # Early stopping defined through max.iter
+#' fit <- gensvm(x, y, max.iter=1000)
 #'
-gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
-                   weights='unit', kernel='linear', gamma='auto', coef=0.0,
-                   degree=2.0, kernel.eigen.cutoff=1e-8, verbose=0,
+#' # Nonlinear training
+#' fit <- gensvm(x, y, kernel='rbf')
+#' fit <- gensvm(x, y, kernel='poly', degree=2, gamma=1.0)
+#'
+#' # Setting the random seed and comparing results
+#' fit <- gensvm(x, y, random.seed=123)
+#' fit2 <- gensvm(x, y, random.seed=123)
+#' all.equal(coef(fit), coef(fit2))
+#'
+#'
+gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
+                   weights='unit', kernel='linear', gamma='auto', coef=1.0,
+                   degree=2.0, kernel.eigen.cutoff=1e-8, verbose=FALSE,
                    random.seed=NULL, max.iter=1e8, seed.V=NULL)
 {
     call <- match.call()
@@ -72,9 +121,6 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
     if (is.null(random.seed))
         random.seed <- runif(1) * (2**31 - 1)
 
-    # TODO: Store a labelencoder in the object, preferably as a partially
-    # hidden item. This can then be used with prediction.
-
     n.objects <- nrow(X)
     n.features <- ncol(X)
     n.classes <- length(unique(y))
@@ -90,17 +136,23 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
     # Convert weights to index
     weight.idx <- which(c("unit", "group") == weights)
     if (length(weight.idx) == 0) {
-        stop("Incorrect weight specification. ",
+        cat("Error: Incorrect weight specification. ",
            "Valid options are 'unit' and 'group'")
+        return(NULL)
     }
 
     # Convert kernel to index (remember off-by-one for R vs. C)
     kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
     if (length(kernel.idx) == 0) {
-        stop("Incorrect kernel specification. ",
+        cat("Error: Incorrect kernel specification. ",
", "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") + return } + seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V) + seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V) + + # Call the C train routine out <- .Call("R_gensvm_train", as.matrix(X), as.integer(y.clean), @@ -118,18 +170,24 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, as.integer(max.iter), as.integer(random.seed), seed.V, + as.integer(seed.rows), + as.integer(seed.cols), as.integer(n.objects), as.integer(n.features), as.integer(n.classes)) + # build the output object object <- list(call = call, p = p, lambda = lambda, kappa = kappa, epsilon = epsilon, weights = weights, kernel = kernel, gamma = gamma, coef = coef, degree = degree, kernel.eigen.cutoff = kernel.eigen.cutoff, - random.seed = random.seed, max.iter = max.iter, - n.objects = n.objects, n.features = n.features, - n.classes = n.classes, classes = classes, V = out$V, - n.iter = out$n.iter, n.support = out$n.support) + verbose = verbose, random.seed = random.seed, + max.iter = max.iter, n.objects = n.objects, + n.features = n.features, n.classes = n.classes, + classes = classes, V = out$V, n.iter = out$n.iter, + n.support = out$n.support, + training.time = out$training.time, + X.train = if(kernel == 'linear') NULL else X) class(object) <- "gensvm" return(object) |
