aboutsummaryrefslogtreecommitdiff
path: root/R/gensvm.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/gensvm.R')
-rw-r--r--R/gensvm.R92
1 files changed, 75 insertions, 17 deletions
diff --git a/R/gensvm.R b/R/gensvm.R
index 4a4ab6b..40930ac 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -1,7 +1,8 @@
#' @title Fit the GenSVM model
#'
#' @description Fits the Generalized Multiclass Support Vector Machine model
-#' with the given parameters.
+#' with the given parameters. See the package documentation
+#' (\code{\link{gensvm-package}}) for more general information about GenSVM.
#'
#' @param X data matrix with the predictors
#' @param y class labels
@@ -12,7 +13,8 @@
#' @param weights type of instance weights to use. Options are 'unit' for unit
#' weights and 'group' for group size correction weight (eq. 4 in the paper).
#' @param kernel the kernel type to use in the classifier. It must be one of
-#' 'linear', 'poly', 'rbf', or 'sigmoid'.
+#' 'linear', 'poly', 'rbf', or 'sigmoid'. See the section "Kernels in GenSVM"
+#' in \code{\link{gensvm-package}} for more info.
#' @param gamma kernel parameter for the rbf, polynomial, and sigmoid kernel.
#' If gamma is 'auto', then 1/n_features will be used.
#' @param coef parameter for the polynomial and sigmoid kernel.
@@ -25,25 +27,45 @@
#' @param random.seed Seed for the random number generator (useful for
#' reproducible output)
#' @param max.iter Maximum number of iterations of the optimization algorithm.
+#' @param seed.V Matrix to warm-start the optimization algorithm. This is
+#' typically the output of \code{coef(fit)}. Note that this function will
+#' silently drop seed.V if the dimensions don't match the provided data.
#'
#' @return A "gensvm" S3 object is returned for which the print, predict, coef,
#' and plot methods are available. It has the following items:
#' \item{call}{The call that was used to construct the model.}
+#' \item{p}{The value of the lp norm in the loss function}
#' \item{lambda}{The regularization parameter used in the model.}
#' \item{kappa}{The hinge function parameter used.}
#' \item{epsilon}{The stopping criterion used.}
#' \item{weights}{The instance weights type used.}
#' \item{kernel}{The kernel function used.}
-#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}.
+#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}
#' \item{coef}{The value of the coef parameter of the kernel, if applicable}
#' \item{degree}{The degree of the kernel, if applicable}
#' \item{kernel.eigen.cutoff}{The cutoff value of the reduced
-#' eigendecomposition of the kernel matrix}
+#' eigendecomposition of the kernel matrix.}
+#' \item{verbose}{Whether or not the model was fitted with progress output}
#' \item{random.seed}{The random seed used to seed the model.}
#' \item{max.iter}{Maximum number of iterations of the algorithm.}
+#' \item{n.objects}{Number of objects in the dataset}
+#' \item{n.features}{Number of features in the dataset}
+#' \item{n.classes}{Number of classes in the dataset}
+#' \item{classes}{Array with the actual class labels}
+#' \item{V}{Coefficient matrix}
+#' \item{n.iter}{Number of iterations performed in training}
+#' \item{n.support}{Number of support vectors in the final model}
+#' \item{training.time}{Total training time}
+#' \item{X.train}{When training with nonlinear kernels, the training data is
+#' needed to perform prediction. For these kernels it is therefore stored in
+#' the fitted model.}
+#'
+#' @note
+#' This function returns partial results when the computation is interrupted by
+#' the user.
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -59,10 +81,37 @@
#' @useDynLib gensvm_wrapper, .registration = TRUE
#'
#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # fit using the default parameters
+#' fit <- gensvm(x, y)
+#'
+#' # fit and show progress
+#' fit <- gensvm(x, y, verbose=T)
+#'
+#' # fit with some changed parameters
+#' fit <- gensvm(x, y, lambda=1e-8)
+#'
+#' # Early stopping defined through epsilon
+#' fit <- gensvm(x, y, epsilon=1e-3)
+#'
+#' # Early stopping defined through max.iter
+#' fit <- gensvm(x, y, max.iter=1000)
#'
-gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
- weights='unit', kernel='linear', gamma='auto', coef=0.0,
- degree=2.0, kernel.eigen.cutoff=1e-8, verbose=0,
+#' # Nonlinear training
+#' fit <- gensvm(x, y, kernel='rbf')
+#' fit <- gensvm(x, y, kernel='poly', degree=2, gamma=1.0)
+#'
+#' # Setting the random seed and comparing results
+#' fit <- gensvm(x, y, random.seed=123)
+#' fit2 <- gensvm(x, y, random.seed=123)
+#' all.equal(coef(fit), coef(fit2))
+#'
+#'
+gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
+ weights='unit', kernel='linear', gamma='auto', coef=1.0,
+ degree=2.0, kernel.eigen.cutoff=1e-8, verbose=FALSE,
random.seed=NULL, max.iter=1e8, seed.V=NULL)
{
call <- match.call()
@@ -72,9 +121,6 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
if (is.null(random.seed))
random.seed <- runif(1) * (2**31 - 1)
- # TODO: Store a labelencoder in the object, preferably as a partially
- # hidden item. This can then be used with prediction.
-
n.objects <- nrow(X)
n.features <- ncol(X)
n.classes <- length(unique(y))
@@ -90,17 +136,23 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
# Convert weights to index
weight.idx <- which(c("unit", "group") == weights)
if (length(weight.idx) == 0) {
- stop("Incorrect weight specification. ",
+ cat("Error: Incorrect weight specification. ",
"Valid options are 'unit' and 'group'")
+ return
}
# Convert kernel to index (remember off-by-one for R vs. C)
kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
if (length(kernel.idx) == 0) {
- stop("Incorrect kernel specification. ",
+ cat("Error: Incorrect kernel specification. ",
"Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'")
+ return
}
+ seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V)
+ seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V)
+
+ # Call the C train routine
out <- .Call("R_gensvm_train",
as.matrix(X),
as.integer(y.clean),
@@ -118,18 +170,24 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
as.integer(max.iter),
as.integer(random.seed),
seed.V,
+ as.integer(seed.rows),
+ as.integer(seed.cols),
as.integer(n.objects),
as.integer(n.features),
as.integer(n.classes))
+ # build the output object
object <- list(call = call, p = p, lambda = lambda, kappa = kappa,
epsilon = epsilon, weights = weights, kernel = kernel,
gamma = gamma, coef = coef, degree = degree,
kernel.eigen.cutoff = kernel.eigen.cutoff,
- random.seed = random.seed, max.iter = max.iter,
- n.objects = n.objects, n.features = n.features,
- n.classes = n.classes, classes = classes, V = out$V,
- n.iter = out$n.iter, n.support = out$n.support)
+ verbose = verbose, random.seed = random.seed,
+ max.iter = max.iter, n.objects = n.objects,
+ n.features = n.features, n.classes = n.classes,
+ classes = classes, V = out$V, n.iter = out$n.iter,
+ n.support = out$n.support,
+ training.time = out$training.time,
+ X.train = if(kernel == 'linear') NULL else X)
class(object) <- "gensvm"
return(object)