| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 12:31:28 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 12:31:28 +0100 |
| commit | 004941896bac692d354c41a3334d20ee1d4627f7 | |
| tree | 2b11e42d8524843409e2bf8deb4ceb74c8b69347 /R | |
| parent | updates to GenSVM C library | |
GenSVM R package
Diffstat (limited to 'R')
| -rw-r--r-- | R/coef.gensvm.R | 6 |
| -rw-r--r-- | R/coef.gensvm.grid.R | 32 |
| -rw-r--r-- | R/gensvm-kernels.R | 10 |
| -rw-r--r-- | R/gensvm-package.R | 81 |
| -rw-r--r-- | R/gensvm.R | 92 |
| -rw-r--r-- | R/gensvm.accuracy.R | 37 |
| -rw-r--r-- | R/gensvm.grid.R | 626 |
| -rw-r--r-- | R/gensvm.maxabs.scale.R | 77 |
| -rw-r--r-- | R/gensvm.refit.R | 82 |
| -rw-r--r-- | R/gensvm.train.test.split.R | 121 |
| -rw-r--r-- | R/plot.gensvm.R | 199 |
| -rw-r--r-- | R/plot.gensvm.grid.R | 39 |
| -rw-r--r-- | R/predict.gensvm.R | 67 |
| -rw-r--r-- | R/predict.gensvm.grid.R | 47 |
| -rw-r--r-- | R/print.gensvm.R | 74 |
| -rw-r--r-- | R/print.gensvm.grid.R | 61 |
16 files changed, 1506 insertions, 145 deletions
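The diff below adds the package's user-facing functions. For orientation, a minimal end-to-end sketch of the intended workflow, assembled from the examples in the documentation that follows (illustrative only, not part of the commit):

```r
library(gensvm)

x <- iris[, -5]
y <- iris[, 5]

# split, fit, predict, and score with the functions added in this commit
split <- gensvm.train.test.split(x, y)
fit <- gensvm(split$x.train, split$y.train)
gensvm.accuracy(split$y.test, predict(fit, split$x.test))

# cross-validated grid search over the default ('tiny') parameter grid
grid <- gensvm.grid(split$x.train, split$y.train)
print(grid)
```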
diff --git a/R/coef.gensvm.R b/R/coef.gensvm.R
index 19ab0aa..45eeb13 100644
--- a/R/coef.gensvm.R
+++ b/R/coef.gensvm.R
@@ -13,7 +13,7 @@
 #' (n_{classes} - 1)} matrix formed by the remaining rows.
 #'
 #' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
 #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
 #'
 #' @references
@@ -25,7 +25,11 @@
 #' @export
 #'
 #' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
 #'
+#' fit <- gensvm(x, y)
+#' V <- coef(fit)
 #'
 coef.gensvm <- function(object, ...)
 {
diff --git a/R/coef.gensvm.grid.R b/R/coef.gensvm.grid.R
new file mode 100644
index 0000000..15e6525
--- /dev/null
+++ b/R/coef.gensvm.grid.R
@@ -0,0 +1,32 @@
+#' @title Get the parameter grid from a GenSVM Grid object
+#'
+#' @description Returns the parameter grid of a \code{gensvm.grid} object.
+#'
+#' @param object a \code{gensvm.grid} object
+#' @param \dots further arguments are ignored
+#'
+#' @return The parameter grid of the GenSVMGrid object as a data frame.
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @method coef gensvm.grid
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' grid <- gensvm.grid(x, y)
+#' pg <- coef(grid)
+#'
+coef.gensvm.grid <- function(object, ...)
+{
+    return(object$param.grid)
+}
diff --git a/R/gensvm-kernels.R b/R/gensvm-kernels.R
deleted file mode 100644
index 8e445c0..0000000
--- a/R/gensvm-kernels.R
+++ /dev/null
@@ -1,10 +0,0 @@
-#' Kernels in GenSVM
-#'
-#' GenSVM can be used for both linear multiclass support vector machine
-#' classification and for nonlinear classification with kernels. In general,
-#' linear classification will be faster but depending on the dataset higher
-#' classification performance can be achieved using a nonlinear kernel.
-#'
-#' The following nonlinear kernels are implemented in the GenSVM package:
-#' \describe{
-#' \item{RBF}{The Radial Basis Function kernel is a commonly used kernel.
diff --git a/R/gensvm-package.R b/R/gensvm-package.R
index 13c2c31..f664577 100644
--- a/R/gensvm-package.R
+++ b/R/gensvm-package.R
@@ -1,15 +1,15 @@
 #' GenSVM: A Generalized Multiclass Support Vector Machine
 #'
 #' The GenSVM classifier is a generalized multiclass support vector machine
-#' (SVM). This classifier simultaneously aims to find decision boundaries that
-#' separate the classes with as wide a margin as possible. In GenSVM, the loss
-#' functions that measures how misclassifications are counted is very flexible.
-#' This allows the user to tune the classifier to the dataset at hand and
+#' (SVM). This classifier aims to find decision boundaries that separate the
+#' classes with as wide a margin as possible. In GenSVM, the loss function
+#' that measures how misclassifications are counted is very flexible. This
+#' allows the user to tune the classifier to the dataset at hand and
 #' potentially obtain higher classification accuracy. Moreover, this
 #' flexibility means that GenSVM has a number of alternative multiclass SVMs as
 #' special cases. One of the other advantages of GenSVM is that it is trained
-#' in the primal, allowing the use of warm starts during optimization.
This -#' means that for common tasks such as cross validation or repeated model +#' in the primal space, allowing the use of warm starts during optimization. +#' This means that for common tasks such as cross validation or repeated model #' fitting, GenSVM can be trained very quickly. #' #' This package provides functions for training the GenSVM model either as a @@ -26,19 +26,71 @@ #' GenSVM.} #' } #' -#' Other available functions are: +#' For the GenSVM and GenSVMGrid models the following two functions are +#' available. When applied to a GenSVMGrid object, the function is applied to +#' the best GenSVM model. #' \describe{ #' \item{\code{\link{plot}}}{Plot the low-dimensional \emph{simplex} space -#' where the decision boundaries are fixed.} +#' where the decision boundaries are fixed (for problems with 3 classes).} #' \item{\code{\link{predict}}}{Predict the class labels of new data using the #' GenSVM model.} -#' \item{\code{\link{coef}}}{Get the coefficients of the GenSVM model} -#' \item{\code{\link{print}}}{Print a short description of the fitted GenSVM -#' model} #' } #' +#' Moreover, for the GenSVM and GenSVMGrid models a \code{coef} function is +#' defined: +#' \describe{ +#' \item{\code{\link{coef.gensvm}}}{Get the coefficients of the fitted GenSVM +#' model.} +#' \item{\code{\link{coef.gensvm.grid}}}{Get the parameter grid of the GenSVM +#' grid search.} +#' } +#' +#' The following utility functions are also included: +#' \describe{ +#' \item{\code{\link{gensvm.accuracy}}}{Compute the accuracy score between true +#' and predicted class labels} +#' \item{\code{\link{gensvm.maxabs.scale}}}{Scale each column of the dataset by +#' its maximum absolute value, preserving sparsity and mapping the data to [-1, +#' 1]} +#' \item{\code{\link{gensvm.train.test.split}}}{Split a dataset into a training +#' and testing sample} +#' \item{\code{\link{gensvm.refit}}}{Refit a fitted GenSVM model with slightly +#' different parameters or on a different dataset} +#' } +#' +#' @section Kernels in GenSVM: +#' +#' GenSVM can be used for both linear and nonlinear multiclass support vector +#' machine classification. In general, linear classification will be faster but +#' depending on the dataset higher classification performance can be achieved +#' using a nonlinear kernel. +#' +#' The following nonlinear kernels are implemented in the GenSVM package: +#' \describe{ +#' \item{RBF}{The Radial Basis Function kernel is a well-known kernel function +#' based on the Euclidean distance between objects. It is defined as +#' \deqn{ +#' k(x_i, x_j) = exp( -\gamma || x_i - x_j ||^2 ) +#' } +#' } +#' \item{Polynomial}{A polynomial kernel can also be used in GenSVM. This +#' kernel function is implemented very generally and therefore takes three +#' parameters (\code{coef}, \code{gamma}, and \code{degree}). It is defined +#' as: +#' \deqn{ +#' k(x_i, x_j) = ( \gamma x_i' x_j + coef)^{degree} +#' } +#' } +#' \item{Sigmoid}{The sigmoid kernel is the final kernel implemented in +#' GenSVM. This kernel has two parameters and is implemented as follows: +#' \deqn{ +#' k(x_i, x_j) = \tanh( \gamma x_i' x_j + coef) +#' } +#' } +#' } +#' #' @author -#' Gerrit J.J. van den Burg, Patrick J.F. Groenen +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> #' #' @references @@ -46,11 +98,10 @@ #' Multiclass Support Vector Machine}, Journal of Machine Learning Research, #' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. 
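For reference, the three kernels documented above can be written directly in R. This is a sketch; the helper names below are illustrative and not part of the package, which evaluates kernels in C:

```r
# single-pair kernel evaluations matching the formulas documented above;
# gamma, coef, and degree correspond to the documented kernel parameters
rbf.kernel <- function(xi, xj, gamma) {
    exp(-gamma * sum((xi - xj)^2))
}
poly.kernel <- function(xi, xj, gamma, coef, degree) {
    (gamma * sum(xi * xj) + coef)^degree
}
sigmoid.kernel <- function(xi, xj, gamma, coef) {
    tanh(gamma * sum(xi * xj) + coef)
}
```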
#' -#' @examples -#' +#' @aliases +#' gensvm.package #' #' @name gensvm-package #' @docType package -#' @import NULL #>NULL @@ -1,7 +1,8 @@ #' @title Fit the GenSVM model #' #' @description Fits the Generalized Multiclass Support Vector Machine model -#' with the given parameters. +#' with the given parameters. See the package documentation +#' (\code{\link{gensvm-package}}) for more general information about GenSVM. #' #' @param X data matrix with the predictors #' @param y class labels @@ -12,7 +13,8 @@ #' @param weights type of instance weights to use. Options are 'unit' for unit #' weights and 'group' for group size correction weight (eq. 4 in the paper). #' @param kernel the kernel type to use in the classifier. It must be one of -#' 'linear', 'poly', 'rbf', or 'sigmoid'. +#' 'linear', 'poly', 'rbf', or 'sigmoid'. See the section "Kernels in GenSVM" +#' in \code{\link{gensvm-package}} for more info. #' @param gamma kernel parameter for the rbf, polynomial, and sigmoid kernel. #' If gamma is 'auto', then 1/n_features will be used. #' @param coef parameter for the polynomial and sigmoid kernel. @@ -25,25 +27,45 @@ #' @param random.seed Seed for the random number generator (useful for #' reproducible output) #' @param max.iter Maximum number of iterations of the optimization algorithm. +#' @param seed.V Matrix to warm-start the optimization algorithm. This is +#' typically the output of \code{coef(fit)}. Note that this function will +#' silently drop seed.V if the dimensions don't match the provided data. #' #' @return A "gensvm" S3 object is returned for which the print, predict, coef, #' and plot methods are available. It has the following items: #' \item{call}{The call that was used to construct the model.} +#' \item{p}{The value of the lp norm in the loss function} #' \item{lambda}{The regularization parameter used in the model.} #' \item{kappa}{The hinge function parameter used.} #' \item{epsilon}{The stopping criterion used.} #' \item{weights}{The instance weights type used.} #' \item{kernel}{The kernel function used.} -#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}. +#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable} #' \item{coef}{The value of the coef parameter of the kernel, if applicable} #' \item{degree}{The degree of the kernel, if applicable} #' \item{kernel.eigen.cutoff}{The cutoff value of the reduced -#' eigendecomposition of the kernel matrix} +#' eigendecomposition of the kernel matrix.} +#' \item{verbose}{Whether or not the model was fitted with progress output} #' \item{random.seed}{The random seed used to seed the model.} #' \item{max.iter}{Maximum number of iterations of the algorithm.} +#' \item{n.objects}{Number of objects in the dataset} +#' \item{n.features}{Number of features in the dataset} +#' \item{n.classes}{Number of classes in the dataset} +#' \item{classes}{Array with the actual class labels} +#' \item{V}{Coefficient matrix} +#' \item{n.iter}{Number of iterations performed in training} +#' \item{n.support}{Number of support vectors in the final model} +#' \item{training.time}{Total training time} +#' \item{X.train}{When training with nonlinear kernels, the training data is +#' needed to perform prediction. For these kernels it is therefore stored in +#' the fitted model.} +#' +#' @note +#' This function returns partial results when the computation is interrupted by +#' the user. #' #' @author -#' Gerrit J.J. van den Burg, Patrick J.F. Groenen +#' Gerrit J.J. van den Burg, Patrick J.F. 
Groenen \cr #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> #' #' @references @@ -59,10 +81,37 @@ #' @useDynLib gensvm_wrapper, .registration = TRUE #' #' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' # fit using the default parameters +#' fit <- gensvm(x, y) +#' +#' # fit and show progress +#' fit <- gensvm(x, y, verbose=T) +#' +#' # fit with some changed parameters +#' fit <- gensvm(x, y, lambda=1e-8) +#' +#' # Early stopping defined through epsilon +#' fit <- gensvm(x, y, epsilon=1e-3) +#' +#' # Early stopping defined through max.iter +#' fit <- gensvm(x, y, max.iter=1000) #' -gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, - weights='unit', kernel='linear', gamma='auto', coef=0.0, - degree=2.0, kernel.eigen.cutoff=1e-8, verbose=0, +#' # Nonlinear training +#' fit <- gensvm(x, y, kernel='rbf') +#' fit <- gensvm(x, y, kernel='poly', degree=2, gamma=1.0) +#' +#' # Setting the random seed and comparing results +#' fit <- gensvm(x, y, random.seed=123) +#' fit2 <- gensvm(x, y, random.seed=123) +#' all.equal(coef(fit), coef(fit2)) +#' +#' +gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6, + weights='unit', kernel='linear', gamma='auto', coef=1.0, + degree=2.0, kernel.eigen.cutoff=1e-8, verbose=FALSE, random.seed=NULL, max.iter=1e8, seed.V=NULL) { call <- match.call() @@ -72,9 +121,6 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, if (is.null(random.seed)) random.seed <- runif(1) * (2**31 - 1) - # TODO: Store a labelencoder in the object, preferably as a partially - # hidden item. This can then be used with prediction. - n.objects <- nrow(X) n.features <- ncol(X) n.classes <- length(unique(y)) @@ -90,17 +136,23 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, # Convert weights to index weight.idx <- which(c("unit", "group") == weights) if (length(weight.idx) == 0) { - stop("Incorrect weight specification. ", + cat("Error: Incorrect weight specification. ", "Valid options are 'unit' and 'group'") + return } # Convert kernel to index (remember off-by-one for R vs. C) kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 if (length(kernel.idx) == 0) { - stop("Incorrect kernel specification. ", + cat("Error: Incorrect kernel specification. 
", "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") + return } + seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V) + seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V) + + # Call the C train routine out <- .Call("R_gensvm_train", as.matrix(X), as.integer(y.clean), @@ -118,18 +170,24 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6, as.integer(max.iter), as.integer(random.seed), seed.V, + as.integer(seed.rows), + as.integer(seed.cols), as.integer(n.objects), as.integer(n.features), as.integer(n.classes)) + # build the output object object <- list(call = call, p = p, lambda = lambda, kappa = kappa, epsilon = epsilon, weights = weights, kernel = kernel, gamma = gamma, coef = coef, degree = degree, kernel.eigen.cutoff = kernel.eigen.cutoff, - random.seed = random.seed, max.iter = max.iter, - n.objects = n.objects, n.features = n.features, - n.classes = n.classes, classes = classes, V = out$V, - n.iter = out$n.iter, n.support = out$n.support) + verbose = verbose, random.seed = random.seed, + max.iter = max.iter, n.objects = n.objects, + n.features = n.features, n.classes = n.classes, + classes = classes, V = out$V, n.iter = out$n.iter, + n.support = out$n.support, + training.time = out$training.time, + X.train = if(kernel == 'linear') NULL else X) class(object) <- "gensvm" return(object) diff --git a/R/gensvm.accuracy.R b/R/gensvm.accuracy.R new file mode 100644 index 0000000..dbcd3cc --- /dev/null +++ b/R/gensvm.accuracy.R @@ -0,0 +1,37 @@ +#' @title Compute the accuracy score +#' +#' @param y.true vector of true labels +#' @param y.pred vector of predicted labels +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @seealso +#' \code{\link{predict.gensvm.grid}} +#' +#' @export +#' +#' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' fit <- gensvm(x, y) +#' gensvm.accuracy(predict(fit, x), y) +#' +gensvm.accuracy <- function(y.true, y.pred) +{ + n <- length(y.true) + if (n != length(y.pred)) { + cat("Error: Can't compute accuracy if vector don't have the ", + "same length\n") + return + } + + return (sum(y.true == y.pred) / n) +} diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R index 37e2f7f..5d27fde 100644 --- a/R/gensvm.grid.R +++ b/R/gensvm.grid.R @@ -3,44 +3,17 @@ #' @description This function performs a cross-validated grid search of the #' model parameters to find the best hyperparameter configuration for a given #' dataset. This function takes advantage of GenSVM's ability to use warm -#' starts to speed up computation. The function also uses the GenSVM C library -#' for speed. -#' -#' There are two ways to use this function: either by providing a data frame -#' with the parameter configurations to try or by giving each of the function -#' inputs a vector of values to evaluate. In the latter case all combinations -#' of the provided values will be used (i.e. the product set). +#' starts to speed up computation. The function uses the GenSVM C library for +#' speed. #' #' @param X training data matrix. We denote the size of this matrix by #' n_samples x n_features. #' @param y training vector of class labes of length n_samples. The number of #' unique labels in this vector is denoted by n_classes. 
-#' @param df Data frame with parameter configurations to evaluate. -#' If this is provided it overrides the other parameter ranges provided. The -#' data frame must provide *all* required columns, as described below. -#' @param p vector of values to try for the \eqn{p} hyperparameter -#' for the \eqn{\ell_p} norm in the loss function. All values should be on the -#' interval [1.0, 2.0]. -#' @param lambda vector of values for the regularization parameter -#' \eqn{\lambda} in the loss function. All values should be larger than 0. -#' @param kappa vector of values for the hinge function parameter in -#' the loss function. All values should be larger than -1. -#' @param weights vector of values for the instance weights. Values -#' should be either 'unit', 'group', or both. -#' @param kernel vector of values for the kernel type. Possible -#' values are: 'linear', 'rbf', 'poly', or 'sigmoid', or any combination of -#' these values. See the article \link[=gensvm-kernels]{Kernels in GenSVM} for -#' more information. -#' @param gamma kernel parameter for the 'rbf', 'poly', and 'sigmoid' kernels. -#' If it is 'auto', 1/n_features will be used. See the article -#' \link[=gensvm-kernels]{Kernels in GenSVM} for more information. -#' @param coef kernel parameter for the 'poly' and 'sigmoid' -#' kernels. See the article \link[=gensvm-kernels]{Kernels in GenSVM} for more -#' information. -#' @param degree kernel parameter for the 'poly' kernel. See the -#' article \link[=gensvm-kernels]{Kernels in GenSVM} for more information. -#' @param max.iter maximum number of iterations to run in the -#' optimization algorithm. +#' @param param.grid String (\code{'tiny'}, \code{'small'}, or \code{'full'}) +#' or data frame with parameter configurations to evaluate. Typically this is +#' the output of \code{expand.grid}. For more details, see "Using a Parameter +#' Grid" below. #' @param refit boolean variable. If true, the best model from cross validation #' is fitted again on the entire dataset. #' @param scoring metric to use to evaluate the classifier performance during @@ -49,29 +22,94 @@ #' values are better. If it is NULL, the accuracy score will be used. #' @param cv the number of cross-validation folds to use or a vector with the #' same length as \code{y} where each unique value denotes a test split. -#' @param verbose boolean variable to indicate whether training details should -#' be printed. 
+#' @param verbose integer to indicate the level of verbosity (higher is more
+#' verbose)
+#' @param return.train.score whether or not to return the scores on the
+#' training splits
 #'
 #' @return A "gensvm.grid" S3 object with the following items:
+#' \item{call}{Call that produced this object}
+#' \item{param.grid}{Sorted version of the parameter grid used in training}
 #' \item{cv.results}{A data frame with the cross validation results}
 #' \item{best.estimator}{If refit=TRUE, this is the GenSVM model fitted with
 #' the best hyperparameter configuration, otherwise it is NULL}
-#' \item{best.score}{Mean cross-validated score for the model with the best
-#' hyperparameter configuration}
+#' \item{best.score}{Mean cross-validated test score for the model with the
+#' best hyperparameter configuration}
 #' \item{best.params}{Parameter configuration that provided the highest mean
-#' cross-validated score}
+#' cross-validated test score}
 #' \item{best.index}{Row index of the cv.results data frame that corresponds to
 #' the best hyperparameter configuration}
 #' \item{n.splits}{The number of cross-validation splits}
+#' \item{n.objects}{The number of instances in the data}
+#' \item{n.features}{The number of features of the data}
+#' \item{n.classes}{The number of classes in the data}
+#' \item{classes}{Array with the unique classes in the data}
+#' \item{total.time}{Training time for the grid search}
+#' \item{cv.idx}{Array with cross validation indices used to split the data}
+#'
+#' @section Using a Parameter Grid:
+#' To evaluate certain parameter configurations, a data frame can be supplied
+#' to the \code{param.grid} argument of the function. Such a data frame can
+#' easily be generated using the R function \code{expand.grid}, or could be
+#' created through other ways to test specific parameter configurations.
+#'
+#' Three parameter grids are predefined:
+#' \describe{
+#' \item{\code{'tiny'}}{This parameter grid is generated by the function
+#' \code{\link{gensvm.load.tiny.grid}} and is the default parameter grid. It
+#' consists of parameter configurations that are likely to perform well on
+#' various datasets.}
+#' \item{\code{'small'}}{This grid is generated by
+#' \code{\link{gensvm.load.small.grid}} and generates a data frame with 90
+#' configurations. It is typically fast to train but contains some
+#' configurations that are unlikely to perform well. It is included for
+#' educational purposes.}
+#' \item{\code{'full'}}{This grid loads the parameter grid as used in the
+#' GenSVM paper. It consists of 342 configurations and is generated by the
+#' \code{\link{gensvm.load.full.grid}} function. Note that in the GenSVM paper
+#' cross validation was done with this parameter grid, but the final training
+#' step used \code{epsilon=1e-8}. The \code{\link{gensvm.refit}} function is
+#' useful in this scenario.}
+#' }
 #'
+#' When you provide your own parameter grid, beware that only certain column
+#' names are allowed in the data frame corresponding to parameters for the
+#' GenSVM model. These names are:
 #'
+#' \describe{
+#' \item{p}{Parameter for the lp norm. Must be in [1.0, 2.0].}
+#' \item{kappa}{Parameter for the Huber hinge function. Must be larger than
+#' -1.}
+#' \item{lambda}{Parameter for the regularization term. Must be larger than 0.}
+#' \item{weight}{Instance weight specification. Allowed values are "unit" for
+#' unit weights and "group" for group-size correction weights}
+#' \item{epsilon}{Stopping parameter for the algorithm. Must be larger than 0.}
+#' \item{max.iter}{Maximum number of iterations of the algorithm. Must be
+#' larger than 0.}
+#' \item{kernel}{The kernel to use, allowed values are "linear", "poly",
+#' "rbf", and "sigmoid". The default is "linear"}
+#' \item{coef}{Parameter for the "poly" and "sigmoid" kernels. See the section
+#' "Kernels in GenSVM" in the \code{\link{gensvm-package}} page for more info.}
+#' \item{degree}{Parameter for the "poly" kernel. See the section "Kernels in
+#' GenSVM" in the \code{\link{gensvm-package}} page for more info.}
+#' \item{gamma}{Parameter for the "poly", "rbf", and "sigmoid" kernels. See the
+#' section "Kernels in GenSVM" in the \code{\link{gensvm-package}} page for
+#' more info.}
+#' }
 #'
-#' @section Using a DataFrame:
-#' ...
 #'
+#' For variables that are not present in the \code{param.grid} data frame the
+#' default parameter values in the \code{\link{gensvm}} function will be used.
 #'
+#' Note that this function reorders the parameter grid to make the warm starts
+#' as efficient as possible, which is why the param.grid in the result will not
+#' be the same as the param.grid in the input.
+#'
+#' @note
+#' This function returns partial results when the computation is interrupted by
+#' the user.
 #'
 #' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
 #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
 #'
 #' @references
@@ -80,37 +118,499 @@
 #' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
 #'
 #' @seealso
-#' \code{\link{coef}}, \code{\link{print}}, \code{\link{predict}},
-#' \code{\link{plot}}, and \code{\link{gensvm.grid}}.
-#'
+#' \code{\link{predict.gensvm.grid}}, \code{\link{print.gensvm.grid}}, and
+#' \code{\link{gensvm}}.
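To make the column conventions above concrete, a small sketch of a custom parameter grid built with `expand.grid` and passed to `gensvm.grid` (assuming `x` and `y` as in the examples; column names are restricted to the set documented above):

```r
# grid over p, lambda, and kappa; unspecified parameters fall back to
# the gensvm() defaults, as documented above
pg <- expand.grid(p = c(1.0, 1.5, 2.0),
                  lambda = c(1e-6, 1e-4, 1e-2),
                  kappa = c(-0.9, 0.5, 5.0))
grid <- gensvm.grid(x, y, param.grid = pg, cv = 5)
grid$best.params
```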
#' #' @export #' #' @examples -#' X <- -#' - -gensvm.grid <- function(X, y, - df=NULL, - p=c(1.0, 1.5, 2.0), - lambda=c(1e-8, 1e-6, 1e-4, 1e-2, 1), - kappa=c(-0.9, 0.5, 5.0), - weights=c('unit', 'group'), - kernel=c('linear'), - gamma=c('auto'), - coef=c(0.0), - degree=c(2.0), - max.iter=c(1e8), - refit=TRUE, - scoring=NULL, - cv=3, - verbose=TRUE) +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' # use the default parameter grid +#' grid <- gensvm.grid(x, y) +#' +#' # use a smaller parameter grid +#' pg <- expand.grid(p=c(1.0, 1.5, 2.0), kappa=c(-0.9, 1.0), epsilon=c(1e-3)) +#' grid <- gensvm.grid(x, y, param.grid=pg) +#' +#' # print the result +#' print(grid) +#' +#' # Using a custom scoring function (accuracy as percentage) +#' acc.pct <- function(yt, yp) { return (100 * sum(yt == yp) / length(yt)) } +#' grid <- gensvm.grid(x, y, scoring=acc.pct) +#' +gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3, + verbose=0, return.train.score=TRUE) { call <- match.call() + n.objects <- nrow(X) + n.features <- ncol(X) + n.classes <- length(unique(y)) + + if (is.character(param.grid)) { + if (param.grid == 'tiny') { + param.grid <- gensvm.load.tiny.grid() + } else if (param.grid == 'small') { + param.grid <- gensvm.load.small.grid() + } else if (param.grid == 'full') { + param.grid <- gensvm.load.full.grid() + } + } + + # Validate the range of the values for the gridsearch + gensvm.validate.param.grid(param.grid) + + # Sort the parameter grid for efficient warm starts + param.grid <- gensvm.sort.param.grid(param.grid) + + # Expand and convert the parameter grid for use in the C function + C.param.grid <- gensvm.expand.param.grid(param.grid, n.features) + + # Convert labels to integers + classes <- sort(unique(y)) + y.clean <- match(y, classes) + + if (is.vector(cv) && length(cv) == n.objects) { + folds <- sort(unique(cv)) + cv.idx <- match(cv, folds) - 1 + n.splits <- length(folds) + } else { + cv.idx <- gensvm.generate.cv.idx(n.objects, cv[1]) + n.splits <- cv + } + + results <- .Call("R_gensvm_grid", + as.matrix(X), + as.integer(y.clean), + as.matrix(C.param.grid), + as.integer(nrow(C.param.grid)), + as.integer(ncol(C.param.grid)), + as.integer(cv.idx), + as.integer(n.splits), + as.integer(verbose), + as.integer(n.objects), + as.integer(n.features), + as.integer(n.classes) + ) + + cv.results <- gensvm.cv.results(results, param.grid, cv.idx, + y.clean, scoring, + return.train.score=return.train.score) + best.index <- which.min(cv.results$rank.test.score)[1] + if (!is.na(best.index)) { # can occur when user interrupts + best.score <- cv.results$mean.test.score[best.index] + best.params <- param.grid[best.index, , drop=F] + # Remove duplicate attributes from best.params + attr(best.params, "out.attrs") <- NULL + } else { + best.score <- NA + best.params <- list() + } - object <- list(...) 
+ if (refit && !is.na(best.index)) { + gensvm.args <- as.list(best.params) + gensvm.args$X <- X + gensvm.args$y <- y + best.estimator <- do.call(gensvm, gensvm.args) + } else { + best.estimator <- NULL + } + + object <- list(call = call, param.grid = param.grid, + cv.results = cv.results, best.estimator = best.estimator, + best.score = best.score, best.params = best.params, + best.index = best.index, n.splits = n.splits, + n.objects = n.objects, n.features = n.features, + n.classes = n.classes, classes = classes, + total.time = results$total.time, cv.idx = cv.idx) class(object) <- "gensvm.grid" return(object) } + +#' @title Load a tiny parameter grid for the GenSVM grid search +#' +#' @description This function returns a parameter grid to use in the GenSVM +#' grid search. This grid was obtained by analyzing the experiments done for +#' the GenSVM paper and selecting the configurations that achieve accuracy +#' within the 95th percentile on over 90% of the datasets. It is a good start +#' for a parameter search with a reasonably high chance of achieving good +#' performance on most datasets. +#' +#' Note that this grid is only tested to work well in combination with the +#' linear kernel. +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @export +#' +#' @seealso +#' \code{\link{gensvm.grid}}, \code{\link{gensvm.load.small.grid}}, +#' \code{\link{gensvm.load.full.grid}}. +#' +gensvm.load.tiny.grid <- function() +{ + df <- data.frame( + p=c(2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 2.0, 2.0), + kappa=c(5.0, 5.0, 0.5, 5.0, -0.9, 5.0, 0.5, -0.9, 0.5, 0.5), + lambda = c(2^-16, 2^-18, 2^-18, 2^-18, 2^-18, 2^-14, 2^-18, + 2^-18, 2^-16, 2^-16), + weight = c('unit', 'unit', 'unit', 'group', 'unit', + 'unit', 'group', 'unit', 'unit', 'group') + ) + return(df) +} + +#' @title Load a large parameter grid for the GenSVM grid search +#' +#' @description This loads the parameter grid from the GenSVM paper. It +#' consists of 342 configurations and is constructed from all possible +#' combinations of the following parameter sets: +#' +#' \code{p = c(1.0, 1.5, 2.0)} +#' +#' \code{lambda = 2^seq(-18, 18, 2)} +#' +#' \code{kappa = c(-0.9, 0.5, 5.0)} +#' +#' \code{weight = c('unit', 'group')} +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @export +#' +#' @seealso +#' \code{\link{gensvm.grid}}, \code{\link{gensvm.load.tiny.grid}}, +#' \code{\link{gensvm.load.full.grid}}. +#' +gensvm.load.full.grid <- function() +{ + df <- expand.grid(p=c(1.0, 1.5, 2.0), lambda=2^seq(-18, 18, 2), + kappa=c(-0.9, 0.5, 5.0), weight=c('unit', 'group'), + epsilon=c(1e-6)) + return(df) +} + + +#' @title Load the default parameter grid for the GenSVM grid search +#' +#' @description This function loads a default parameter grid to use for the +#' GenSVM gridsearch. 
It contains all possible combinations of the following +#' parameter sets: +#' +#' \code{p = c(1.0, 1.5, 2.0)} +#' +#' \code{lambda = c(1e-8, 1e-6, 1e-4, 1e-2, 1)} +#' +#' \code{kappa = c(-0.9, 0.5, 5.0)} +#' +#' \code{weight = c('unit', 'group')} +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @export +#' +#' @seealso +#' \code{\link{gensvm.grid}}, \code{\link{gensvm.load.tiny.grid}}, +#' \code{\link{gensvm.load.small.grid}}. +#' +gensvm.load.small.grid <- function() +{ + df <- expand.grid(p=c(1.0, 1.5, 2.0), lambda=c(1e-8, 1e-6, 1e-4, 1e-2, 1), + kappa=c(-0.9, 0.5, 5.0), weight=c('unit', 'group')) + return(df) +} + + +#' Generate a vector of cross-validation indices +#' +#' This function generates a vector of length \code{n} with values from 0 to +#' \code{folds-1} to mark train and test splits. +#' +gensvm.generate.cv.idx <- function(n, folds) +{ + cv.idx <- matrix(0, n, 1) + + big.folds <- n %% folds + small.fold.size <- n %/% folds + + j <- 0 + for (i in 0:(small.fold.size * folds)) { + while (TRUE) { + idx <- round(runif(1, 1, n)) + if (cv.idx[idx] == 0) { + cv.idx[idx] <- j + j <- j + 1 + j <- (j %% folds) + break + } + } + } + + j <- 1 + i <- 0 + while (i < big.folds) { + if (cv.idx[j] == 0) { + cv.idx[j] <- i + i <- i + 1 + } + j <- j + 1 + } + + return(cv.idx) +} + +gensvm.validate.param.grid <- function(df) +{ + expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight", + "kappa", "lambda", "p", "epsilon", "max.iter") + for (name in colnames(df)) { + if (!(name %in% expected.colnames)) { + stop("Invalid header name supplied in parameter grid: ", name) + } + } + + conditions <- list( + p=function(x) { x >= 1.0 && x <= 2.0 }, + kappa=function(x) { x > -1.0 }, + lambda=function(x) {x > 0.0 }, + epsilon=function(x) { x > 0.0 }, + gamma=function(x) { x != 0.0 }, + weight=function(x) { x %in% c("unit", "group") }, + kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } + ) + + for (idx in 1:nrow(df)) { + for (param in colnames(df)) { + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- df[[param]][idx] + if (!func(value)) + stop("Invalid value in grid for parameter: ", param) + } + } +} + +gensvm.cv.results <- function(results, param.grid, cv.idx, y.true, + scoring, return.train.score=TRUE) +{ + n.candidates <- nrow(param.grid) + n.splits <- length(unique(cv.idx)) + + score <- if(is.function(scoring)) scoring else gensvm.accuracy + + # Build names and initialize the data.frame + names <- c("mean.fit.time", "mean.score.time", "mean.test.score") + if (return.train.score) + names <- c(names, "mean.train.score") + for (param in names(param.grid)) { + names <- c(names, sprintf("param.%s", param)) + } + names <- c(names, "rank.test.score") + for (idx in sort(unique(cv.idx))) { + names <- c(names, sprintf("split%i.test.score", idx)) + if (return.train.score) + names <- c(names, sprintf("split%i.train.score", idx)) + } + names <- c(names, "std.fit.time", "std.score.time", "std.test.score") + if (return.train.score) + names <- c(names, "std.train.score") + + df <- data.frame(matrix(ncol=length(names), nrow=n.candidates)) + colnames(df) <- names + + for (pidx in 1:n.candidates) { + param <- 
param.grid[pidx, , drop=F] + durations <- results$durations[pidx, ] + predictions <- results$predictions[pidx, ] + + fit.times <- durations + score.times <- c() + test.scores <- c() + train.scores <- c() + + is.missing <- any(is.na(durations)) + + for (test.idx in sort(unique(cv.idx))) { + score.time <- 0 + + if (return.train.score) { + y.train.pred <- predictions[cv.idx != test.idx] + y.train.true <- y.true[cv.idx != test.idx] + + start.time <- proc.time() + train.score <- score(y.train.true, y.train.pred) + stop.time <- proc.time() + score.time <- score.time + (stop.time - start.time)[3] + + train.scores <- c(train.scores, train.score) + } + + y.test.pred <- predictions[cv.idx == test.idx] + y.test.true <- y.true[cv.idx == test.idx] + + start.time <- proc.time() + test.score <- score(y.test.true, y.test.pred) + stop.time <- proc.time() + score.time <- score.time + (stop.time - start.time)[3] + + test.scores <- c(test.scores, test.score) + + score.times <- c(score.times, score.time) + } + + df$mean.fit.time[pidx] <- mean(fit.times) + df$mean.score.time[pidx] <- if(is.missing) NA else mean(score.times) + df$mean.test.score[pidx] <- mean(test.scores) + df$std.fit.time[pidx] <- sd(fit.times) + df$std.score.time[pidx] <- if(is.missing) NA else sd(score.times) + df$std.test.score[pidx] <- sd(test.scores) + if (return.train.score) { + df$mean.train.score[pidx] <- mean(train.scores) + df$std.train.score[pidx] <- sd(train.scores) + } + + for (parname in names(param.grid)) { + df[[sprintf("param.%s", parname)]][pidx] <- param[[parname]] + } + + j <- 1 + for (test.idx in sort(unique(cv.idx))) { + lbl <- sprintf("split%i.test.score", test.idx) + df[[lbl]][pidx] <- test.scores[j] + if (return.train.score) { + lbl <- sprintf("split%i.train.score", test.idx) + df[[lbl]][pidx] <- train.scores[j] + } + j <- j + 1 + } + } + + df$rank.test.score <- gensvm.rank.score(df$mean.test.score) + + return(df) +} + +gensvm.sort.param.grid <- function(param.grid) +{ + all.cols <- c("kernel", "coef", "degree", "gamma", "weight", "kappa", + "lambda", "p", "epsilon", "max.iter") + + order.args <- NULL + for (name in all.cols) { + if (name %in% colnames(param.grid)) { + if (name == "epsilon") + order.args <- cbind(order.args, -param.grid[[name]]) + else + order.args <- cbind(order.args, param.grid[[name]]) + } + } + sorted.pg <- param.grid[do.call(order, as.list(as.data.frame(order.args))), ] + + rownames(sorted.pg) <- NULL + + return(sorted.pg) +} + +gensvm.expand.param.grid <- function(pg, n.features) +{ + if ("kernel" %in% colnames(pg)) { + all.kernels <- c("linear", "poly", "rbf", "sigmoid") + pg$kernel <- match(pg$kernel, all.kernels) - 1 + } else { + pg$kernel <- 0 + } + + if ("weight" %in% colnames(pg)) { + all.weights <- c("unit", "group") + pg$weight <- match(pg$weight, all.weights) + } else { + pg$weight <- 1 + } + + if ("gamma" %in% colnames(pg)) { + pg$gamma[pg$gamma == "auto"] <- 1.0/n.features + } else { + pg$gamma <- 1.0/n.features + } + + if (!("degree" %in% colnames(pg))) + pg$degree <- 2.0 + if (!("coef" %in% colnames(pg))) + pg$coef <- 0.0 + if (!("p" %in% colnames(pg))) + pg$p <- 1.0 + if (!("lambda" %in% colnames(pg))) + pg$lambda <- 1e-8 + if (!("kappa" %in% colnames(pg))) + pg$kappa <- 0.0 + if (!("epsilon" %in% colnames(pg))) + pg$epsilon <- 1e-6 + if (!("max.iter" %in% colnames(pg))) + pg$max.iter <- 1e8 + + C.param.grid <- data.frame(kernel=pg$kernel, coef=pg$coef, + degree=pg$degree, gamma=pg$gamma, + weight=pg$weight, kappa=pg$kappa, + lambda=pg$lambda, p=pg$p, epsilon=pg$epsilon, + 
max.iter=pg$max.iter)
+
+    return(C.param.grid)
+}
+
+#' @title Compute the ranks for the numbers in a given vector
+#'
+#' @details
+#' This function computes the ranks for the values in an array. The highest
+#' value gets the smallest rank. Tied values are all assigned the same
+#' (smallest available) rank.
+#'
+#' @param x array of numeric values
+#'
+#' @examples
+#' x <- c(7, 0.1, 0.5, 0.1, 10, 100, 200)
+#' gensvm.rank.score(x)
+#' # 4 6 5 6 3 2 1
+#'
+gensvm.rank.score <- function(x)
+{
+    x <- as.array(x)
+    l <- length(x)
+    r <- 1
+    ranks <- as.vector(matrix(0, l, 1))
+    ranks[which(is.na(x))] <- NA
+    while (!all(mapply(is.na, x))) {
+        m <- max(x, na.rm=TRUE)
+        idx <- which(x == m)
+        ranks[idx] <- r
+        r <- r + length(idx)
+        x[idx] <- NA
+    }
+
+    return(ranks)
+}
diff --git a/R/gensvm.maxabs.scale.R b/R/gensvm.maxabs.scale.R
new file mode 100644
index 0000000..6ac351b
--- /dev/null
+++ b/R/gensvm.maxabs.scale.R
@@ -0,0 +1,77 @@
+#' @title Scale each column of a matrix by its maximum absolute value
+#'
+#' @description Scaling a dataset can greatly decrease the computation time of
+#' GenSVM. This function scales the data by dividing each column of a matrix by
+#' the maximum absolute value of that column. This preserves sparsity in the
+#' data while mapping each column to the interval [-1, 1].
+#'
+#' Optionally a test dataset can be provided as well. In this case, the scaling
+#' will be computed on the first argument (\code{x}) and applied to the test
+#' dataset. Note that the return value is a list when this argument is
+#' supplied.
+#'
+#' @param x a matrix to scale
+#' @param x.test (optional) a test matrix to scale as well.
+#'
+#' @return if x.test=NULL a scaled matrix where the maximum value of the
+#' columns is 1 and the minimum value of the columns isn't below -1. If x.test
+#' is supplied, a list with elements \code{x} and \code{x.test} representing
+#' the scaled datasets.
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
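The scaling described above amounts to a per-column division by the maximum absolute value; a one-line base-R equivalent for reference (illustrative only; unlike the package implementation below, it does not guard against all-zero columns):

```r
# divide each column of x by that column's maximum absolute value
x.scaled <- sweep(as.matrix(x), 2, apply(abs(as.matrix(x)), 2, max), "/")
```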
+#'
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#'
+#' # check the min and max of the columns
+#' apply(x, 2, min)
+#' apply(x, 2, max)
+#'
+#' # scale the data
+#' x.scale <- gensvm.maxabs.scale(x)
+#'
+#' # check again (max should be 1.0, min shouldn't be below -1)
+#' apply(x.scale, 2, min)
+#' apply(x.scale, 2, max)
+#'
+#' # with a train and test dataset
+#' x <- iris[, -5]
+#' split <- gensvm.train.test.split(x)
+#' x.train <- split$x.train
+#' x.test <- split$x.test
+#' scaled <- gensvm.maxabs.scale(x.train, x.test)
+#' x.train.scl <- scaled$x
+#' x.test.scl <- scaled$x.test
+#'
+gensvm.maxabs.scale <- function(x, x.test=NULL)
+{
+    xm <- as.matrix(x)
+    max.abs <- apply(apply(xm, 2, abs), 2, max)
+    max.abs[max.abs == 0] <- 1
+
+    scaled <- xm %*% diag(1.0 / max.abs)
+    colnames(scaled) <- colnames(x)
+    rownames(scaled) <- rownames(x)
+
+    if (!is.null(x.test)) {
+        xtm <- as.matrix(x.test)
+        scaled.test <- xtm %*% diag(1.0 / max.abs)
+        colnames(scaled.test) <- colnames(x.test)
+        rownames(scaled.test) <- rownames(x.test)
+
+        ret.val <- list(x=scaled, x.test=scaled.test)
+    } else {
+        ret.val <- scaled
+    }
+
+    return(ret.val)
+}
diff --git a/R/gensvm.refit.R b/R/gensvm.refit.R
new file mode 100644
index 0000000..a6af3fd
--- /dev/null
+++ b/R/gensvm.refit.R
@@ -0,0 +1,82 @@
+#' @title Train an already fitted model on new data
+#'
+#' @description This function can be used to train an existing model on new
+#' data or fit an existing model with slightly different parameters. It is
+#' useful for retraining without having to copy all the parameters over. One
+#' common application for this is to refit the best model found by a grid
+#' search, as illustrated in the examples.
+#'
+#' @param fit Fitted \code{gensvm} object
+#' @param X Data matrix of the new data
+#' @param y Label vector of the new data
+#' @param p,lambda,kappa,epsilon,weights,kernel,gamma,coef,degree,kernel.eigen.cutoff,max.iter,random.seed
+#' optional model parameters; if NULL (the default) the value from the fitted
+#' model is used.
+#' @param verbose Turn on verbose output and fit progress. If NULL (the
+#' default) the value from the fitted model is chosen.
+#'
+#' @return a new fitted \code{gensvm} model
+#'
+#' @export
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # fit a standard model and refit with slightly different parameters
+#' fit <- gensvm(x, y)
+#' fit2 <- gensvm.refit(fit, x, y, epsilon=1e-8)
+#'
+#' # refit the best model returned by a grid search
+#' grid <- gensvm.grid(x, y)
+#' fit <- gensvm.refit(grid$best.estimator, x, y, epsilon=1e-8)
+#'
+#' # refit on different data
+#' idx <- runif(nrow(x)) > 0.5
+#' x1 <- x[idx, ]
+#' x2 <- x[!idx, ]
+#' y1 <- y[idx]
+#' y2 <- y[!idx]
+#'
+#' fit1 <- gensvm(x1, y1)
+#' fit2 <- gensvm.refit(fit1, x2, y2)
+#'
+gensvm.refit <- function(fit, X, y, p=NULL, lambda=NULL, kappa=NULL,
+                         epsilon=NULL, weights=NULL, kernel=NULL, gamma=NULL,
+                         coef=NULL, degree=NULL, kernel.eigen.cutoff=NULL,
+                         max.iter=NULL, verbose=NULL, random.seed=NULL)
+{
+    p <- if(is.null(p)) fit$p else p
+    lambda <- if(is.null(lambda)) fit$lambda else lambda
+    kappa <- if(is.null(kappa)) fit$kappa else kappa
+    epsilon <- if(is.null(epsilon)) fit$epsilon else epsilon
+    weights <- if(is.null(weights)) fit$weights else weights
+    kernel <- if(is.null(kernel)) fit$kernel else kernel
+    gamma <- if(is.null(gamma)) fit$gamma else gamma
+    coef <- if(is.null(coef)) fit$coef else coef
+    degree <- if(is.null(degree)) fit$degree else degree
+    kernel.eigen.cutoff <- (if(is.null(kernel.eigen.cutoff))
+                            fit$kernel.eigen.cutoff else kernel.eigen.cutoff)
+    max.iter <- if(is.null(max.iter)) fit$max.iter else max.iter
+    verbose <- if(is.null(verbose)) fit$verbose else verbose
+    random.seed <- if(is.null(random.seed)) fit$random.seed else random.seed
+
+    # Setting the error handler here is necessary in case the user interrupts
+    # this call to gensvm. If we don't set the error handler, R will
+    # unnecessarily drop to a browser() session. We reset the error handler
+    # after the call to gensvm().
+    options(error=function() {})
+    newfit <- gensvm(X, y, p=p, lambda=lambda, kappa=kappa, epsilon=epsilon,
+                     weights=weights, kernel=kernel, gamma=gamma, coef=coef,
+                     degree=degree, kernel.eigen.cutoff=kernel.eigen.cutoff,
+                     verbose=verbose, max.iter=max.iter, seed.V=coef(fit))
+    options(error=NULL)
+
+    return(newfit)
+}
diff --git a/R/gensvm.train.test.split.R b/R/gensvm.train.test.split.R
new file mode 100644
index 0000000..406f80e
--- /dev/null
+++ b/R/gensvm.train.test.split.R
@@ -0,0 +1,121 @@
+#' @title Create a train/test split of a dataset
+#'
+#' @description Often it is desirable to split a dataset into a training and
+#' testing sample. This function is included in GenSVM to make it easy to do
+#' so. The function is inspired by a similar function in Scikit-Learn.
+#'
+#' @param x array to split
+#' @param y another array to split (typically this is a vector)
+#' @param train.size size of the training dataset. This can be provided as
+#' float or as int. If it's a float, it should be between 0.0 and 1.0 and
+#' represents the fraction of the dataset that should be placed in the training
+#' dataset. If it's an int, it represents the exact number of samples in the
+#' training dataset. If it is NULL, the complement of \code{test.size} will be
+#' used.
+#' @param test.size size of the test dataset. Similarly to train.size both a
+#' float or an int can be supplied. If it's NULL, the complement of train.size
+#' will be used. If both train.size and test.size are NULL, a default test.size
+#' of 0.25 will be used.
+#' @param shuffle shuffle the rows or not
+#' @param random.state seed for the random number generator (int)
+#' @param return.idx if TRUE, also return the indices of the train and test
+#' samples in the input data
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F.
Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @export +#' +#' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' # using the default values +#' split <- gensvm.train.test.split(x, y) +#' +#' # using the split in a GenSVM model +#' fit <- gensvm(split$x.train, split$y.train) +#' gensvm.accuracy(split$y.test, predict(fit, split$x.test)) +#' +#' # using attach makes the results directly available +#' attach(gensvm.train.test.split(x, y)) +#' fit <- gensvm(x.train, y.train) +#' gensvm.accuracy(y.test, predict(fit, x.test)) +#' +gensvm.train.test.split <- function(x, y=NULL, train.size=NULL, test.size=NULL, + shuffle=TRUE, random.state=NULL, + return.idx=FALSE) +{ + if (!is.null(y) && dim(as.matrix(x))[1] != dim(as.matrix(y))[1]) { + cat("Error: First dimension of x and y should be equal.\n") + return + } + + n.objects <- dim(as.matrix(x))[1] + + if (is.null(train.size) && is.null(test.size)) { + test.size <- round(0.25 * n.objects) + train.size <- n.objects - test.size + } + else if (is.null(train.size)) { + if (test.size > 0.0 && test.size < 1.0) + test.size <- round(n.objects * test.size) + train.size <- n.objects - test.size + } + else if (is.null(test.size)) { + if (train.size > 0.0 && train.size < 1.0) + train.size <- round(n.objects * train.size) + test.size <- n.objects - train.size + } + else { + if (train.size > 0.0 && train.size < 1.0) + train.size <- round(n.objects * train.size) + if (test.size > 0.0 && test.size < 1.0) + test.size <- round(n.objects * test.size) + } + + if (!is.null(random.state)) + set.seed(random.state) + + if (shuffle) { + train.idx <- sample(n.objects, train.size) + diff <- setdiff(1:n.objects, train.idx) + test.idx <- sample(diff, test.size) + } else { + train.idx <- 1:train.size + diff <- setdiff(1:n.objects, train.idx) + test.idx <- diff[1:test.size] + } + + x.train <- x[train.idx, ] + x.test <- x[test.idx, ] + + if (!is.null(y)) { + if (is.matrix(y)) { + y.train <- y[train.idx, ] + y.test <- y[test.idx, ] + } else { + y.train <- y[train.idx] + y.test <- y[test.idx] + } + } + + out <- list( + x.train = x.train, + x.test = x.test + ) + if (!is.null(y)) { + out$y.train <- y.train + out$y.test <- y.test + } + if (return.idx) { + out$idx.train <- train.idx + out$idx.test <- test.idx + } + + return(out) +} diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R new file mode 100644 index 0000000..0ce215b --- /dev/null +++ b/R/plot.gensvm.R @@ -0,0 +1,199 @@ +#' @title Plot the simplex space of the fitted GenSVM model +#' +#' @description This function creates a plot of the simplex space for a fitted +#' GenSVM model and the given data set, as long as the dataset consists of only +#' 3 classes. For more than 3 classes, the simplex space is too high +#' dimensional to easily visualize. +#' +#' @param fit A fitted \code{gensvm} object +#' @param x the dataset to plot +#' @param y.true the true data labels. If provided the objects will be colored +#' using the true labels instead of the predicted labels. This makes it easy to +#' identify misclassified objects. 
+#' @param with.margins plot the margins
+#' @param with.shading show shaded areas for the class regions
+#' @param with.legend show the legend for the class labels
+#' @param center.plot ensure that the boundaries and margins are always visible
+#' in the plot
+#' @param ... further arguments are ignored
+#'
+#' @return returns the object passed as input
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @method plot gensvm
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # train the model
+#' fit <- gensvm(x, y)
+#'
+#' # plot the simplex space
+#' plot(fit, x)
+#'
+#' # plot and use the true colors (easier to spot misclassified samples)
+#' plot(fit, x, y.true=y)
+#'
+#' # plot only misclassified samples
+#' x.mis <- x[predict(fit, x) != y, ]
+#' y.mis.true <- y[predict(fit, x) != y]
+#' plot(fit, x.mis)
+#' plot(fit, x.mis, y.true=y.mis.true)
+#'
+plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
+                        with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
+                        ...)
+{
+    if (fit$n.classes != 3) {
+        cat("Error: Can only plot with 3 classes\n")
+        return(invisible(NULL))
+    }
+
+    # Sanity check
+    if (ncol(x) != fit$n.features) {
+        cat("Error: Number of features of fitted model and",
+            "testing data disagree.\n")
+        return(invisible(NULL))
+    }
+
+    x.train <- fit$X.train
+    if (fit$kernel != 'linear' && is.null(x.train)) {
+        cat("Error: The training data is needed to plot data for ",
+            "nonlinear GenSVM. This data is not present in the fitted ",
+            "model!\n", sep="")
+        return(invisible(NULL))
+    }
+    if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
+        cat("Error: Number of features of fitted model and",
+            "training data disagree.\n")
+        return(invisible(NULL))
+    }
+
+    x <- as.matrix(x)
+
+    if (fit$kernel == 'linear') {
+        V <- coef(fit)
+        Z <- cbind(matrix(1, dim(x)[1], 1), x)
+        S <- Z %*% V
+        y.pred.orig <- predict(fit, x)
+    } else {
+        kernels <- c("linear", "poly", "rbf", "sigmoid")
+        kernel.idx <- which(kernels == fit$kernel) - 1
+        plotdata <- .Call("R_gensvm_plotdata_kernels",
+                          as.matrix(x),
+                          as.matrix(x.train),
+                          as.matrix(fit$V),
+                          as.integer(nrow(fit$V)),
+                          as.integer(ncol(fit$V)),
+                          as.integer(nrow(x.train)),
+                          as.integer(nrow(x)),
+                          as.integer(fit$n.features),
+                          as.integer(fit$n.classes),
+                          as.integer(kernel.idx),
+                          fit$gamma,
+                          fit$coef,
+                          fit$degree,
+                          fit$kernel.eigen.cutoff
+                          )
+        S <- plotdata$ZV
+        y.pred.orig <- plotdata$y.pred
+    }
+
+    classes <- fit$classes
+    if (is.factor(y.pred.orig)) {
+        y.pred <- match(y.pred.orig, classes)
+    } else {
+        y.pred <- y.pred.orig
+    }
+
+    # Define some colors
+    point.blue <- rgb(31, 119, 180, maxColorValue=255)
+    point.orange <- rgb(255, 127, 14, maxColorValue=255)
+    point.green <- rgb(44, 160, 44, maxColorValue=255)
+    fill.blue <- rgb(31, 119, 180, 51, maxColorValue=255)
+    fill.orange <- rgb(255, 127, 14, 51, maxColorValue=255)
+    fill.green <- rgb(44, 160, 44, 51, maxColorValue=255)
+
+    colors <- as.matrix(c(point.green, point.blue, point.orange))
+    markers <- as.matrix(c(15, 16, 17))
+
+    if (is.null(y.true)) {
+        col.vector <- colors[y.pred]
+        mark.vector <- markers[y.pred]
+    } else {
+        col.vector <- colors[y.true]
+        mark.vector <- markers[y.true]
+    }
+
+    par(pty="s")
+    if (center.plot) {
+        new.xlim <- c(min(min(S[, 1]), -1.2),
max(max(S[, 1]), 1.2)) + new.ylim <- c(min(min(S[, 2]), -0.75), max(max(S[, 2]), 1.2)) + plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='', + asp=1, xlim=new.xlim, ylim=new.ylim) + } else { + plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='', + asp=1) + } + + limits <- par("usr") + xmin <- limits[1] + xmax <- limits[2] + ymin <- limits[3] + ymax <- limits[4] + + # draw the fixed boundaries + segments(0, 0, 0, ymin) + segments(0, 0, xmax, xmax/sqrt(3)) + segments(xmin, abs(xmin)/sqrt(3), 0, 0) + + if (with.margins) { + # margin from left below decision boundary to center + segments(xmin, -xmin/sqrt(3) - sqrt(4/3), -1, -1/sqrt(3), lty=2) + + # margin from left center to down + segments(-1, -1/sqrt(3), -1, ymin, lty=2) + + # margin from right center to middle + segments(1, -1/sqrt(3), 1, ymin, lty=2) + + # margin from right center to right boundary + segments(1, -1/sqrt(3), xmax, xmax/sqrt(3) - sqrt(4/3), lty=2) + + # margin from center to top left + segments(xmin, -xmin/sqrt(3) + sqrt(4/3), 0, sqrt(4/3), lty=2) + + # margin from center to top right + segments(0, sqrt(4/3), xmax, xmax/sqrt(3) + sqrt(4/3), lty=2) + } + + if (with.shading) { + # bottom left + polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, -1/sqrt(3), -xmin/sqrt(3) - + sqrt(4/3)), col=fill.green, border=NA) + # bottom right + polygon(c(1, xmax, xmax, 1), c(ymin, ymin, xmax/sqrt(3) - sqrt(4/3), + -1/sqrt(3)), col=fill.blue, border=NA) + # top + polygon(c(xmin, 0, xmax, xmax, xmin), + c(-xmin/sqrt(3) + sqrt(4/3), sqrt(4/3), xmax/sqrt(3) + sqrt(4/3), + ymax, ymax), col=fill.orange, + border=NA) + } + + if (with.legend) { + offset <- abs(xmax - xmin) * 0.05 + legend(xmax + offset, ymax, classes, col=colors, pch=markers, xpd=T) + } + + invisible(fit) +} diff --git a/R/plot.gensvm.grid.R b/R/plot.gensvm.grid.R new file mode 100644 index 0000000..da101e6 --- /dev/null +++ b/R/plot.gensvm.grid.R @@ -0,0 +1,39 @@ +#' @title Plot the simplex space of the best fitted model in the GenSVMGrid +#' +#' @description This is a wrapper which calls the plot function for the best +#' model in the provided GenSVMGrid object. See the documentation for +#' \code{\link{plot.gensvm}} for more information. +#' +#' @param grid A \code{gensvm.grid} object trained with refit=TRUE +#' @param x the dataset to plot +#' @param ... further arguments are passed to the plot function +#' +#' @return returns the object passed as input +#' +#' @export +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' grid <- gensvm.grid(x, y) +#' plot(grid, x) +#' +plot.gensvm.grid <- function(grid, x, ...) +{ + if (is.null(grid$best.estimator)) { + cat("Error: Can't plot, the best.estimator element is NULL\n") + return + } + fit <- grid$best.estimator + return(plot(fit, x, ...)) +} diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 5c8f2e7..7e04fe4 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -4,8 +4,8 @@ #' fitted GenSVM model. #' #' @param fit Fitted \code{gensvm} object -#' @param newx Matrix of new values for \code{x} for which predictions need to -#' be made. 
+#' @param x.test Matrix of new values for \code{x} for which predictions need
+#' to be made.
 #' @param \dots further arguments are ignored
 #'
 #' @return a vector of class labels, with the same type as the original class
@@ -15,7 +15,7 @@
 #' @aliases predict
 #'
 #' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
 #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
 #'
 #' @references
@@ -24,10 +24,20 @@
 #' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
 #'
 #' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
 #'
+#' # create a training and test sample
+#' attach(gensvm.train.test.split(x, y))
+#' fit <- gensvm(x.train, y.train)
 #'
+#' # predict the class labels of the test sample
+#' y.test.pred <- predict(fit, x.test)
 #'
-predict.gensvm <- function(fit, newx, ...)
+#' # compute the accuracy with gensvm.accuracy
+#' gensvm.accuracy(y.test, y.test.pred)
+#'
+predict.gensvm <- function(fit, x.test, ...)
 {
     ## Implementation note:
     ## - It might seem that it would be faster to do the prediction directly in
@@ -37,16 +47,54 @@ predict.gensvm(fit, newx, ...)
     ## the C implementation is *much* faster than doing it in R.
 
     # Sanity check
-    if (ncol(newx) != fit$n.features)
-        stop("Number of features of fitted model and supplied data disagree.")
+    if (ncol(x.test) != fit$n.features) {
+        cat("Error: Number of features of fitted model and testing",
+            "data disagree.\n")
+        return(invisible(NULL))
+    }
+
+    x.train <- fit$X.train
+    if (fit$kernel != 'linear' && is.null(x.train)) {
+        cat("Error: The training data is needed to compute predictions for ",
+            "nonlinear GenSVM. This data is not present in the fitted ",
+            "model!\n", sep="")
+        return(invisible(NULL))
+    }
+    if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
+        cat("Error: Number of features of fitted model and training",
+            "data disagree.\n")
+        return(invisible(NULL))
+    }
 
-    y.pred.c <- .Call("R_gensvm_predict",
-            as.matrix(newx),
+    if (fit$kernel == 'linear') {
+        y.pred.c <- .Call("R_gensvm_predict",
+                as.matrix(x.test),
             as.matrix(fit$V),
-            as.integer(nrow(newx)),
-            as.integer(ncol(newx)),
+                as.integer(nrow(x.test)),
+                as.integer(ncol(x.test)),
             as.integer(fit$n.classes)
             )
+    } else {
+        kernels <- c("linear", "poly", "rbf", "sigmoid")
+        kernel.idx <- which(kernels == fit$kernel) - 1
+        y.pred.c <- .Call("R_gensvm_predict_kernels",
+                as.matrix(x.test),
+                as.matrix(x.train),
+                as.matrix(fit$V),
+                as.integer(nrow(fit$V)),
+                as.integer(ncol(fit$V)),
+                as.integer(nrow(x.train)),
+                as.integer(nrow(x.test)),
+                as.integer(fit$n.features),
+                as.integer(fit$n.classes),
+                as.integer(kernel.idx),
+                fit$gamma,
+                fit$coef,
+                fit$degree,
+                fit$kernel.eigen.cutoff
+                )
+    }
+
    yhat <- fit$classes[y.pred.c]
 
    return(yhat)
diff --git a/R/predict.gensvm.grid.R b/R/predict.gensvm.grid.R
new file mode 100644
index 0000000..81a0207
--- /dev/null
+++ b/R/predict.gensvm.grid.R
@@ -0,0 +1,47 @@
+#' @title Predict class labels from the GenSVMGrid class
+#'
+#' @description Predict class labels using the best model from a grid search.
+#' After doing a grid search with the \code{\link{gensvm.grid}} function, this
+#' function can be used to make predictions of class labels. It uses the best
+#' GenSVM model found during the grid search to do the predictions. Note that
+#' this model is only available if \code{refit=TRUE} was specified in the
+#' \code{\link{gensvm.grid}} call (the default).
+#'
+#' @param grid A \code{gensvm.grid} object trained with \code{refit=TRUE}
+#' @param newx Matrix of new values for \code{x} for which predictions need to
+#' be computed.
+#' @param \dots further arguments are passed to \code{\link{predict.gensvm}}
+#'
+#' @return a vector of class labels, with the same type as the original class
+#' labels provided to gensvm.grid()
+#'
+#' @export
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # run a grid search
+#' grid <- gensvm.grid(x, y)
+#'
+#' # predict training sample
+#' y.hat <- predict(grid, x)
+#'
+predict.gensvm.grid <- function(grid, newx, ...)
+{
+    if (is.null(grid$best.estimator)) {
+        cat("Error: Can't predict, the best.estimator element is NULL\n")
+        return(invisible(NULL))
+    }
+
+    return(predict(grid$best.estimator, newx, ...))
+}
diff --git a/R/print.gensvm.R b/R/print.gensvm.R
index 06a3649..119b264 100644
--- a/R/print.gensvm.R
+++ b/R/print.gensvm.R
@@ -2,13 +2,14 @@
 #'
 #' @description Prints a short description of the fitted GenSVM model
 #'
-#' @param object A \code{gensvm} object to print
+#' @param fit A \code{gensvm} object to print
 #' @param \dots further arguments are ignored
 #'
-#' @return returns the object passed as input
+#' @return returns the object passed as input. This can be useful for chaining
+#' operations on a fit object.
 #'
 #' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
 #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
 #'
 #' @references
@@ -20,37 +21,52 @@
 #' @export
 #'
 #' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
 #'
+#' # fit and print the model
+#' fit <- gensvm(x, y)
+#' print(fit)
 #'
-print.gensvm <- function(object, ...)
+#' # (advanced) use the fact that print returns the fitted model
+#' fit <- gensvm(x, y)
+#' predict(print(fit), x)
+#'
+print.gensvm <- function(fit, ...)
{ - cat("\nCall:\n") - dput(object$call) - cat("\nData:\n") - cat("\tn.objects:", object$n.objects, "\n") - cat("\tn.features:", object$n.features, "\n") - cat("\tn.classes:", object$n.classes, "\n") - cat("\tclasses:", object$classes, "\n") + cat("Data:\n") + cat("\tn.objects:", fit$n.objects, "\n") + cat("\tn.features:", fit$n.features, "\n") + cat("\tn.classes:", fit$n.classes, "\n") + if (is.factor(fit$classes)) + cat("\tclasses:", levels(fit$classes), "\n") + else + cat("\tclasses:", fit$classes, "\n") cat("Parameters:\n") - cat("\tp:", object$p, "\n") - cat("\tlambda:", object$lambda, "\n") - cat("\tkappa:", object$kappa, "\n") - cat("\tepsilon:", object$epsilon, "\n") - cat("\tweights:", object$weights, "\n") - cat("\tmax.iter:", object$max.iter, "\n") - cat("\trandom.seed:", object$random.seed, "\n") - cat("\tkernel:", object$kernel, "\n") - if (object$kernel %in% c("poly", "rbf", "sigmoid")) { - cat("\tkernel.eigen.cutoff:", object$kernel.eigen.cutoff, "\n") - cat("\tgamma:", object$gamma, "\n") + cat("\tp:", fit$p, "\n") + cat("\tlambda:", fit$lambda, "\n") + cat("\tkappa:", fit$kappa, "\n") + cat("\tepsilon:", fit$epsilon, "\n") + cat("\tweights:", fit$weights, "\n") + cat("\tmax.iter:", fit$max.iter, "\n") + cat("\trandom.seed:", fit$random.seed, "\n") + if (is.factor(fit$kernel)) { + cat("\tkernel:", levels(fit$kernel)[as.numeric(fit$kernel)], "\n") + } else { + cat("\tkernel:", fit$kernel, "\n") + } + if (fit$kernel %in% c("poly", "rbf", "sigmoid")) { + cat("\tkernel.eigen.cutoff:", fit$kernel.eigen.cutoff, "\n") + cat("\tgamma:", fit$gamma, "\n") } - if (object$kernel %in% c("poly", "sigmoid")) - cat("\tcoef:", object$coef, "\n") - if (object$kernel == 'poly') - cat("\tdegree:", object$degree, "\n") + if (fit$kernel %in% c("poly", "sigmoid")) + cat("\tcoef:", fit$coef, "\n") + if (fit$kernel == 'poly') + cat("\tdegree:", fit$degree, "\n") cat("Results:\n") - cat("\tn.iter:", object$n.iter, "\n") - cat("\tn.support:", object$n.support, "\n") + cat("\ttime:", fit$training.time, "\n") + cat("\tn.iter:", fit$n.iter, "\n") + cat("\tn.support:", fit$n.support, "\n") - invisible(object) + invisible(fit) } diff --git a/R/print.gensvm.grid.R b/R/print.gensvm.grid.R new file mode 100644 index 0000000..88967d7 --- /dev/null +++ b/R/print.gensvm.grid.R @@ -0,0 +1,61 @@ +#' @title Print the fitted GenSVMGrid model +#' +#' @description Prints the summary of the fitted GenSVMGrid model +#' +#' @param grid a \code{gensvm.grid} object to print +#' @param \dots further arguments are ignored +#' +#' @return returns the object passed as input +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @method print gensvm.grid +#' @export +#' +#' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' # fit a grid search and print the resulting object +#' grid <- gensvm.grid(x, y) +#' print(grid) +#' +print.gensvm.grid <- function(grid, ...) 
+{
+    cat("Data:\n")
+    cat("\tn.objects:", grid$n.objects, "\n")
+    cat("\tn.features:", grid$n.features, "\n")
+    cat("\tn.classes:", grid$n.classes, "\n")
+    if (is.factor(grid$classes))
+        cat("\tclasses:", levels(grid$classes), "\n")
+    else
+        cat("\tclasses:", grid$classes, "\n")
+    cat("Config:\n")
+    cat("\tNumber of cv splits:", grid$n.splits, "\n")
+    not.run <- sum(is.na(grid$cv.results$rank.test.score))
+    if (not.run > 0) {
+        cat("\tParameter grid size:", nrow(grid$param.grid))
+        cat(" (", not.run, " incomplete)", sep="")
+        cat("\n")
+    } else {
+        cat("\tParameter grid size:", nrow(grid$param.grid), "\n")
+    }
+    cat("Results:\n")
+    cat("\tTotal grid search time:", grid$total.time, "\n")
+    if (!is.na(grid$best.index)) {
+        best <- grid$cv.results[grid$best.index, ]
+        cat("\tBest mean test score:", best$mean.test.score, "\n")
+        cat("\tBest mean fit time:", best$mean.fit.time, "\n")
+        for (name in colnames(grid$best.params))
+            cat("\tBest parameter", name, "=", grid$best.params[[name]], "\n")
+    }
+
+    invisible(grid)
+}
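The implementation note in predict.gensvm above points out that the prediction could be computed directly in R, but that the C routine is much faster. For reference, here is a minimal pure-R sketch of the linear-kernel prediction rule, assuming only base R and the fields of the fitted object used above (coef(fit), fit$n.classes, fit$classes); the names gensvm.simplex and predict.gensvm.sketch are illustrative and not part of the package:

    # Vertices of a regular (K-1)-simplex centered at the origin, one row per
    # class; this construction is consistent with the K = 3 geometry drawn by
    # plot.gensvm above (see Van den Burg & Groenen, 2016, for the definition).
    gensvm.simplex <- function(K)
    {
        U <- matrix(0, K, K - 1)
        for (t in 1:(K - 1)) {
            U[1:t, t] <- -1 / sqrt(2 * t * (t + 1))
            U[t + 1, t] <- t / sqrt(2 * t * (t + 1))
        }
        U
    }

    predict.gensvm.sketch <- function(fit, x.test)
    {
        V <- coef(fit)                      # (n.features + 1) x (K - 1)
        Z <- cbind(1, as.matrix(x.test))    # augment with a column of ones
        S <- Z %*% V                        # simplex-space coordinates
        U <- gensvm.simplex(fit$n.classes)
        # squared distances ||s_i - u_k||^2 between all rows of S and all
        # simplex vertices, as an n x K matrix
        D <- outer(rowSums(S^2), rowSums(U^2), "+") - 2 * S %*% t(U)
        # assign each observation to the class with the nearest vertex
        fit$classes[max.col(-D)]
    }

The mapping S = Z %*% V is the same one plot.gensvm uses to draw the simplex space; per the implementation note, the C routine R_gensvm_predict computes the same predictions much faster.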
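Relatedly, the constants hard-coded in the segments() calls of plot.gensvm follow from this K = 3 geometry: the three decision boundaries are the perpendicular bisectors of the simplex vertex pairs and meet at the origin (a vertical half-line plus two half-lines with slopes 1/sqrt(3) and -1/sqrt(3)), and each dashed margin line lies at unit perpendicular distance from its boundary. A quick base-R check of that distance for the slanted boundaries, which the code shifts vertically by sqrt(4/3):

    slope <- 1 / sqrt(3)         # slope of the two slanted decision boundaries
    shift <- sqrt(4 / 3)         # vertical offset used for the margin lines
    shift / sqrt(1 + slope^2)    # perpendicular distance between the lines: 1

For the vertical boundary the margins are simply the lines x = -1 and x = 1, which is why those values appear directly in the segments() calls.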
