author     Gertjan van den Burg <gertjanvandenburg@gmail.com>  2018-03-27 12:31:28 +0100
committer  Gertjan van den Burg <gertjanvandenburg@gmail.com>  2018-03-27 12:31:28 +0100
commit     004941896bac692d354c41a3334d20ee1d4627f7 (patch)
tree       2b11e42d8524843409e2bf8deb4ceb74c8b69347 /R
parent     updates to GenSVM C library (diff)
GenSVM R package
Diffstat (limited to 'R')
-rw-r--r--  R/coef.gensvm.R              |    6
-rw-r--r--  R/coef.gensvm.grid.R         |   32
-rw-r--r--  R/gensvm-kernels.R           |   10
-rw-r--r--  R/gensvm-package.R           |   81
-rw-r--r--  R/gensvm.R                   |   92
-rw-r--r--  R/gensvm.accuracy.R          |   37
-rw-r--r--  R/gensvm.grid.R              |  626
-rw-r--r--  R/gensvm.maxabs.scale.R      |   77
-rw-r--r--  R/gensvm.refit.R             |   82
-rw-r--r--  R/gensvm.train.test.split.R  |  121
-rw-r--r--  R/plot.gensvm.R              |  199
-rw-r--r--  R/plot.gensvm.grid.R         |   39
-rw-r--r--  R/predict.gensvm.R           |   67
-rw-r--r--  R/predict.gensvm.grid.R      |   47
-rw-r--r--  R/print.gensvm.R             |   74
-rw-r--r--  R/print.gensvm.grid.R        |   61
16 files changed, 1506 insertions, 145 deletions
diff --git a/R/coef.gensvm.R b/R/coef.gensvm.R
index 19ab0aa..45eeb13 100644
--- a/R/coef.gensvm.R
+++ b/R/coef.gensvm.R
@@ -13,7 +13,7 @@
#' (n_{classes} - 1)} matrix formed by the remaining rows.
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -25,7 +25,11 @@
#' @export
#'
#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
#'
+#' fit <- gensvm(x, y)
+#' V <- coef(fit)
#'
coef.gensvm <- function(object, ...)
{
diff --git a/R/coef.gensvm.grid.R b/R/coef.gensvm.grid.R
new file mode 100644
index 0000000..15e6525
--- /dev/null
+++ b/R/coef.gensvm.grid.R
@@ -0,0 +1,32 @@
+#' @title Get the parameter grid from a GenSVM Grid object
+#'
+#' @description Returns the parameter grid of a \code{gensvm.grid} object.
+#'
+#' @param object a \code{gensvm.grid} object
+#' @param \dots further arguments are ignored
+#'
+#' @return The parameter grid of the GenSVMGrid object as a data frame.
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @method coef gensvm.grid
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' grid <- gensvm.grid(x, y)
+#' pg <- coef(grid)
+#'
+coef.gensvm.grid <- function(object, ...)
+{
+ return(object$param.grid)
+}
diff --git a/R/gensvm-kernels.R b/R/gensvm-kernels.R
deleted file mode 100644
index 8e445c0..0000000
--- a/R/gensvm-kernels.R
+++ /dev/null
@@ -1,10 +0,0 @@
-#' Kernels in GenSVM
-#'
-#' GenSVM can be used for both linear multiclass support vector machine
-#' classification and for nonlinear classification with kernels. In general,
-#' linear classification will be faster but depending on the dataset higher
-#' classification performance can be achieved using a nonlinear kernel.
-#'
-#' The following nonlinear kernels are implemented in the GenSVM package:
-#' \describe{
-#' \item{RBF}{The Radial Basis Function kernel is a commonly used kernel.
diff --git a/R/gensvm-package.R b/R/gensvm-package.R
index 13c2c31..f664577 100644
--- a/R/gensvm-package.R
+++ b/R/gensvm-package.R
@@ -1,15 +1,15 @@
#' GenSVM: A Generalized Multiclass Support Vector Machine
#'
#' The GenSVM classifier is a generalized multiclass support vector machine
-#' (SVM). This classifier simultaneously aims to find decision boundaries that
-#' separate the classes with as wide a margin as possible. In GenSVM, the loss
-#' functions that measures how misclassifications are counted is very flexible.
-#' This allows the user to tune the classifier to the dataset at hand and
+#' (SVM). This classifier aims to find decision boundaries that separate the
+#' classes with as wide a margin as possible. In GenSVM, the loss function
+#' that measures how misclassifications are counted is very flexible. This
+#' allows the user to tune the classifier to the dataset at hand and
#' potentially obtain higher classification accuracy. Moreover, this
#' flexibility means that GenSVM has a number of alternative multiclass SVMs as
#' special cases. One of the other advantages of GenSVM is that it is trained
-#' in the primal, allowing the use of warm starts during optimization. This
-#' means that for common tasks such as cross validation or repeated model
+#' in the primal space, allowing the use of warm starts during optimization.
+#' This means that for common tasks such as cross validation or repeated model
#' fitting, GenSVM can be trained very quickly.
#'
#' This package provides functions for training the GenSVM model either as a
@@ -26,19 +26,71 @@
#' GenSVM.}
#' }
#'
-#' Other available functions are:
+#' For the GenSVM and GenSVMGrid models the following two functions are
+#' available. When applied to a GenSVMGrid object, the function is applied to
+#' the best GenSVM model.
#' \describe{
#' \item{\code{\link{plot}}}{Plot the low-dimensional \emph{simplex} space
-#' where the decision boundaries are fixed.}
+#' where the decision boundaries are fixed (for problems with 3 classes).}
#' \item{\code{\link{predict}}}{Predict the class labels of new data using the
#' GenSVM model.}
-#' \item{\code{\link{coef}}}{Get the coefficients of the GenSVM model}
-#' \item{\code{\link{print}}}{Print a short description of the fitted GenSVM
-#' model}
#' }
#'
+#' Moreover, for the GenSVM and GenSVMGrid models a \code{coef} function is
+#' defined:
+#' \describe{
+#' \item{\code{\link{coef.gensvm}}}{Get the coefficients of the fitted GenSVM
+#' model.}
+#' \item{\code{\link{coef.gensvm.grid}}}{Get the parameter grid of the GenSVM
+#' grid search.}
+#' }
+#'
+#' The following utility functions are also included:
+#' \describe{
+#' \item{\code{\link{gensvm.accuracy}}}{Compute the accuracy score between true
+#' and predicted class labels}
+#' \item{\code{\link{gensvm.maxabs.scale}}}{Scale each column of the dataset by
+#' its maximum absolute value, preserving sparsity and mapping the data to [-1,
+#' 1]}
+#' \item{\code{\link{gensvm.train.test.split}}}{Split a dataset into a training
+#' and testing sample}
+#' \item{\code{\link{gensvm.refit}}}{Refit a fitted GenSVM model with slightly
+#' different parameters or on a different dataset}
+#' }
+#'
+#' @section Kernels in GenSVM:
+#'
+#' GenSVM can be used for both linear and nonlinear multiclass support vector
+#' machine classification. In general, linear classification will be faster but
+#' depending on the dataset higher classification performance can be achieved
+#' using a nonlinear kernel.
+#'
+#' The following nonlinear kernels are implemented in the GenSVM package:
+#' \describe{
+#' \item{RBF}{The Radial Basis Function kernel is a well-known kernel function
+#' based on the Euclidean distance between objects. It is defined as
+#' \deqn{
+#' k(x_i, x_j) = \exp( -\gamma || x_i - x_j ||^2 )
+#' }
+#' }
+#' \item{Polynomial}{A polynomial kernel can also be used in GenSVM. This
+#' kernel function is implemented very generally and therefore takes three
+#' parameters (\code{coef}, \code{gamma}, and \code{degree}). It is defined
+#' as:
+#' \deqn{
+#' k(x_i, x_j) = ( \gamma x_i' x_j + coef)^{degree}
+#' }
+#' }
+#' \item{Sigmoid}{The sigmoid kernel is the final kernel implemented in
+#' GenSVM. This kernel has two parameters and is implemented as follows:
+#' \deqn{
+#' k(x_i, x_j) = \tanh( \gamma x_i' x_j + coef)
+#' }
+#' }
+#' }
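+#'
+#' As a small illustration (a sketch, not part of the package API), the RBF
+#' kernel value between two observations can be computed directly in R:
+#' \preformatted{
+#' rbf.kernel <- function(xi, xj, gamma) {
+#'     # gamma scales the squared Euclidean distance
+#'     exp(-gamma * sum((xi - xj)^2))
+#' }
+#' rbf.kernel(c(1, 2), c(2, 0), gamma=0.25)
+#' }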
+#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -46,11 +98,10 @@
#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
#'
-#' @examples
-#'
+#' @aliases
+#' gensvm.package
#'
#' @name gensvm-package
#' @docType package
-#' @import
NULL
diff --git a/R/gensvm.R b/R/gensvm.R
index 4a4ab6b..40930ac 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -1,7 +1,8 @@
#' @title Fit the GenSVM model
#'
#' @description Fits the Generalized Multiclass Support Vector Machine model
-#' with the given parameters.
+#' with the given parameters. See the package documentation
+#' (\code{\link{gensvm-package}}) for more general information about GenSVM.
#'
#' @param X data matrix with the predictors
#' @param y class labels
@@ -12,7 +13,8 @@
#' @param weights type of instance weights to use. Options are 'unit' for unit
#' weights and 'group' for group size correction weight (eq. 4 in the paper).
#' @param kernel the kernel type to use in the classifier. It must be one of
-#' 'linear', 'poly', 'rbf', or 'sigmoid'.
+#' 'linear', 'poly', 'rbf', or 'sigmoid'. See the section "Kernels in GenSVM"
+#' in \code{\link{gensvm-package}} for more info.
#' @param gamma kernel parameter for the rbf, polynomial, and sigmoid kernel.
#' If gamma is 'auto', then 1/n_features will be used.
#' @param coef parameter for the polynomial and sigmoid kernel.
@@ -25,25 +27,45 @@
#' @param random.seed Seed for the random number generator (useful for
#' reproducible output)
#' @param max.iter Maximum number of iterations of the optimization algorithm.
+#' @param seed.V Matrix to warm-start the optimization algorithm. This is
+#' typically the output of \code{coef(fit)}. Note that this function will
+#' silently drop seed.V if the dimensions don't match the provided data.
#'
#' @return A "gensvm" S3 object is returned for which the print, predict, coef,
#' and plot methods are available. It has the following items:
#' \item{call}{The call that was used to construct the model.}
+#' \item{p}{The value of the lp norm in the loss function}
#' \item{lambda}{The regularization parameter used in the model.}
#' \item{kappa}{The hinge function parameter used.}
#' \item{epsilon}{The stopping criterion used.}
#' \item{weights}{The instance weights type used.}
#' \item{kernel}{The kernel function used.}
-#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}.
+#' \item{gamma}{The value of the gamma parameter of the kernel, if applicable}
#' \item{coef}{The value of the coef parameter of the kernel, if applicable}
#' \item{degree}{The degree of the kernel, if applicable}
#' \item{kernel.eigen.cutoff}{The cutoff value of the reduced
-#' eigendecomposition of the kernel matrix}
+#' eigendecomposition of the kernel matrix.}
+#' \item{verbose}{Whether or not the model was fitted with progress output}
#' \item{random.seed}{The random seed used to seed the model.}
#' \item{max.iter}{Maximum number of iterations of the algorithm.}
+#' \item{n.objects}{Number of objects in the dataset}
+#' \item{n.features}{Number of features in the dataset}
+#' \item{n.classes}{Number of classes in the dataset}
+#' \item{classes}{Array with the actual class labels}
+#' \item{V}{Coefficient matrix}
+#' \item{n.iter}{Number of iterations performed in training}
+#' \item{n.support}{Number of support vectors in the final model}
+#' \item{training.time}{Total training time}
+#' \item{X.train}{When training with nonlinear kernels, the training data is
+#' needed to perform prediction. For these kernels it is therefore stored in
+#' the fitted model.}
+#'
+#' @note
+#' This function returns partial results when the computation is interrupted by
+#' the user.
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -59,10 +81,37 @@
#' @useDynLib gensvm_wrapper, .registration = TRUE
#'
#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # fit using the default parameters
+#' fit <- gensvm(x, y)
+#'
+#' # fit and show progress
+#' fit <- gensvm(x, y, verbose=TRUE)
+#'
+#' # fit with some changed parameters
+#' fit <- gensvm(x, y, lambda=1e-8)
+#'
+#' # Early stopping defined through epsilon
+#' fit <- gensvm(x, y, epsilon=1e-3)
+#'
+#' # Early stopping defined through max.iter
+#' fit <- gensvm(x, y, max.iter=1000)
#'
-gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
- weights='unit', kernel='linear', gamma='auto', coef=0.0,
- degree=2.0, kernel.eigen.cutoff=1e-8, verbose=0,
+#' # Nonlinear training
+#' fit <- gensvm(x, y, kernel='rbf')
+#' fit <- gensvm(x, y, kernel='poly', degree=2, gamma=1.0)
+#'
+#' # Setting the random seed and comparing results
+#' fit <- gensvm(x, y, random.seed=123)
+#' fit2 <- gensvm(x, y, random.seed=123)
+#' all.equal(coef(fit), coef(fit2))
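+#'
+#' # Warm-start a fit from previous coefficients via seed.V (a sketch based
+#' # on the seed.V parameter described above)
+#' fit <- gensvm(x, y, epsilon=1e-4)
+#' fit2 <- gensvm(x, y, epsilon=1e-6, seed.V=coef(fit))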
+#'
+#'
+gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
+ weights='unit', kernel='linear', gamma='auto', coef=1.0,
+ degree=2.0, kernel.eigen.cutoff=1e-8, verbose=FALSE,
random.seed=NULL, max.iter=1e8, seed.V=NULL)
{
call <- match.call()
@@ -72,9 +121,6 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
if (is.null(random.seed))
random.seed <- runif(1) * (2**31 - 1)
- # TODO: Store a labelencoder in the object, preferably as a partially
- # hidden item. This can then be used with prediction.
-
n.objects <- nrow(X)
n.features <- ncol(X)
n.classes <- length(unique(y))
@@ -90,17 +136,23 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
# Convert weights to index
weight.idx <- which(c("unit", "group") == weights)
if (length(weight.idx) == 0) {
- stop("Incorrect weight specification. ",
+ cat("Error: Incorrect weight specification. ",
"Valid options are 'unit' and 'group'")
+        return(NULL)
}
# Convert kernel to index (remember off-by-one for R vs. C)
kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
if (length(kernel.idx) == 0) {
- stop("Incorrect kernel specification. ",
+ cat("Error: Incorrect kernel specification. ",
"Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'")
+        return(NULL)
}
+ seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V)
+ seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V)
+
+ # Call the C train routine
out <- .Call("R_gensvm_train",
as.matrix(X),
as.integer(y.clean),
@@ -118,18 +170,24 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
as.integer(max.iter),
as.integer(random.seed),
seed.V,
+ as.integer(seed.rows),
+ as.integer(seed.cols),
as.integer(n.objects),
as.integer(n.features),
as.integer(n.classes))
+ # build the output object
object <- list(call = call, p = p, lambda = lambda, kappa = kappa,
epsilon = epsilon, weights = weights, kernel = kernel,
gamma = gamma, coef = coef, degree = degree,
kernel.eigen.cutoff = kernel.eigen.cutoff,
- random.seed = random.seed, max.iter = max.iter,
- n.objects = n.objects, n.features = n.features,
- n.classes = n.classes, classes = classes, V = out$V,
- n.iter = out$n.iter, n.support = out$n.support)
+ verbose = verbose, random.seed = random.seed,
+ max.iter = max.iter, n.objects = n.objects,
+ n.features = n.features, n.classes = n.classes,
+ classes = classes, V = out$V, n.iter = out$n.iter,
+ n.support = out$n.support,
+ training.time = out$training.time,
+ X.train = if(kernel == 'linear') NULL else X)
class(object) <- "gensvm"
return(object)
diff --git a/R/gensvm.accuracy.R b/R/gensvm.accuracy.R
new file mode 100644
index 0000000..dbcd3cc
--- /dev/null
+++ b/R/gensvm.accuracy.R
@@ -0,0 +1,37 @@
+#' @title Compute the accuracy score
+#'
+#' @param y.true vector of true labels
+#' @param y.pred vector of predicted labels
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @seealso
+#' \code{\link{predict.gensvm.grid}}
+#'
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' fit <- gensvm(x, y)
+#' gensvm.accuracy(predict(fit, x), y)
+#'
+gensvm.accuracy <- function(y.true, y.pred)
+{
+ n <- length(y.true)
+ if (n != length(y.pred)) {
+ cat("Error: Can't compute accuracy if vector don't have the ",
+ "same length\n")
+ return
+ }
+
+ return (sum(y.true == y.pred) / n)
+}
diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R
index 37e2f7f..5d27fde 100644
--- a/R/gensvm.grid.R
+++ b/R/gensvm.grid.R
@@ -3,44 +3,17 @@
#' @description This function performs a cross-validated grid search of the
#' model parameters to find the best hyperparameter configuration for a given
#' dataset. This function takes advantage of GenSVM's ability to use warm
-#' starts to speed up computation. The function also uses the GenSVM C library
-#' for speed.
-#'
-#' There are two ways to use this function: either by providing a data frame
-#' with the parameter configurations to try or by giving each of the function
-#' inputs a vector of values to evaluate. In the latter case all combinations
-#' of the provided values will be used (i.e. the product set).
+#' starts to speed up computation. The function uses the GenSVM C library for
+#' speed.
#'
#' @param X training data matrix. We denote the size of this matrix by
#' n_samples x n_features.
#' @param y training vector of class labels of length n_samples. The number of
#' unique labels in this vector is denoted by n_classes.
-#' @param df Data frame with parameter configurations to evaluate.
-#' If this is provided it overrides the other parameter ranges provided. The
-#' data frame must provide *all* required columns, as described below.
-#' @param p vector of values to try for the \eqn{p} hyperparameter
-#' for the \eqn{\ell_p} norm in the loss function. All values should be on the
-#' interval [1.0, 2.0].
-#' @param lambda vector of values for the regularization parameter
-#' \eqn{\lambda} in the loss function. All values should be larger than 0.
-#' @param kappa vector of values for the hinge function parameter in
-#' the loss function. All values should be larger than -1.
-#' @param weights vector of values for the instance weights. Values
-#' should be either 'unit', 'group', or both.
-#' @param kernel vector of values for the kernel type. Possible
-#' values are: 'linear', 'rbf', 'poly', or 'sigmoid', or any combination of
-#' these values. See the article \link[=gensvm-kernels]{Kernels in GenSVM} for
-#' more information.
-#' @param gamma kernel parameter for the 'rbf', 'poly', and 'sigmoid' kernels.
-#' If it is 'auto', 1/n_features will be used. See the article
-#' \link[=gensvm-kernels]{Kernels in GenSVM} for more information.
-#' @param coef kernel parameter for the 'poly' and 'sigmoid'
-#' kernels. See the article \link[=gensvm-kernels]{Kernels in GenSVM} for more
-#' information.
-#' @param degree kernel parameter for the 'poly' kernel. See the
-#' article \link[=gensvm-kernels]{Kernels in GenSVM} for more information.
-#' @param max.iter maximum number of iterations to run in the
-#' optimization algorithm.
+#' @param param.grid String (\code{'tiny'}, \code{'small'}, or \code{'full'})
+#' or data frame with parameter configurations to evaluate. Typically this is
+#' the output of \code{expand.grid}. For more details, see "Using a Parameter
+#' Grid" below.
#' @param refit boolean variable. If true, the best model from cross validation
#' is fitted again on the entire dataset.
#' @param scoring metric to use to evaluate the classifier performance during
@@ -49,29 +22,94 @@
#' values are better. If it is NULL, the accuracy score will be used.
#' @param cv the number of cross-validation folds to use or a vector with the
#' same length as \code{y} where each unique value denotes a test split.
-#' @param verbose boolean variable to indicate whether training details should
-#' be printed.
+#' @param verbose integer to indicate the level of verbosity (higher is more
+#' verbose)
+#' @param return.train.score whether or not to return the scores on the
+#' training splits
#'
#' @return A "gensvm.grid" S3 object with the following items:
+#' \item{call}{Call that produced this object}
+#' \item{param.grid}{Sorted version of the parameter grid used in training}
#' \item{cv.results}{A data frame with the cross validation results}
#' \item{best.estimator}{If refit=TRUE, this is the GenSVM model fitted with
#' the best hyperparameter configuration, otherwise it is NULL}
-#' \item{best.score}{Mean cross-validated score for the model with the best
-#' hyperparameter configuration}
+#' \item{best.score}{Mean cross-validated test score for the model with the
+#' best hyperparameter configuration}
#' \item{best.params}{Parameter configuration that provided the highest mean
-#' cross-validated score}
+#' cross-validated test score}
#' \item{best.index}{Row index of the cv.results data frame that corresponds to
#' the best hyperparameter configuration}
#' \item{n.splits}{The number of cross-validation splits}
+#' \item{n.objects}{The number of instances in the data}
+#' \item{n.features}{The number of features of the data}
+#' \item{n.classes}{The number of classes in the data}
+#' \item{classes}{Array with the unique classes in the data}
+#' \item{total.time}{Training time for the grid search}
+#' \item{cv.idx}{Array with cross validation indices used to split the data}
+#'
+#' @section Using a Parameter Grid:
+#' To evaluate certain parameter configurations, a data frame can be supplied
+#' to the \code{param.grid} argument of the function. Such a data frame can
+#' easily be generated using the R function \code{expand.grid}, or could be
+#' created through other ways to test specific parameter configurations.
+#'
+#' Three parameter grids are predefined:
+#' \describe{
+#' \item{\code{'tiny'}}{This parameter grid is generated by the function
+#' \code{\link{gensvm.load.tiny.grid}} and is the default parameter grid. It
+#' consists of parameter configurations that are likely to perform well on
+#' various datasets.}
+#' \item{\code{'small'}}{This grid is generated by
+#' \code{\link{gensvm.load.small.grid}} and generates a data frame with 90
+#' configurations. It is typically fast to train but contains some
+#' configurations that are unlikely to perform well. It is included for
+#' educational purposes.}
+#' \item{\code{'full'}}{This grid loads the parameter grid as used in the
+#' GenSVM paper. It consists of 342 configurations and is generated by the
+#' \code{\link{gensvm.load.full.grid}} function. Note that in the GenSVM paper
+#' cross validation was done with this parameter grid, but the final training
+#' step used \code{epsilon=1e-8}. The \code{\link{gensvm.refit}} function is
+#' useful in this scenario.}
+#' }
#'
+#' When you provide your own parameter grid, beware that only certain column
+#' names are allowed in the data frame corresponding to parameters for the
+#' GenSVM model. These names are:
#'
+#' \describe{
+#' \item{p}{Parameter for the lp norm. Must be in [1.0, 2.0].}
+#' \item{kappa}{Parameter for the Huber hinge function. Must be larger than
+#' -1.}
+#' \item{lambda}{Parameter for the regularization term. Must be larger than 0.}
+#' \item{weight}{Instance weight specification. Allowed values are "unit" for
+#' unit weights and "group" for group-size correction weights.}
+#' \item{epsilon}{Stopping parameter for the algorithm. Must be larger than 0.}
+#' \item{max.iter}{Maximum number of iterations of the algorithm. Must be
+#' larger than 0.}
+#' \item{kernel}{The kernel to use. Allowed values are "linear", "poly",
+#' "rbf", and "sigmoid". The default is "linear".}
+#' \item{coef}{Parameter for the "poly" and "sigmoid" kernels. See the section
+#' "Kernels in GenSVM" in the code{ink{gensvm-package}} page for more info.}
+#' \item{degree}{Parameter for the "poly" kernel. See the section "Kernels in
+#' GenSVM" in the code{ink{gensvm-package}} page for more info.}
+#' \item{gamma}{Parameter for the "poly", "rbf", and "sigmoid" kernels. See the
+#' section "Kernels in GenSVM" in the code{ink{gensvm-package}} page for more
+#' info.}
+#' }
#'
-#' @section Using a DataFrame:
-#' ...
+#' For variables that are not present in the \code{param.grid} data frame the
+#' default parameter values in the \code{\link{gensvm}} function will be used.
#'
+#' Note that this function reorders the parameter grid to make the warm starts
+#' as efficient as possible, which is why the param.grid in the result will not
+#' be the same as the param.grid in the input.
+#'
+#' @note
+#' This function returns partial results when the computation is interrupted by
+#' the user.
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -80,37 +118,499 @@
#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
#'
#' @seealso
-#' \code{\link{coef}}, \code{\link{print}}, \code{\link{predict}},
-#' \code{\link{plot}}, and \code{\link{gensvm.grid}}.
-#'
+#' \code{\link{predict.gensvm.grid}}, \code{\link{print.gensvm.grid}}, and
+#' \code{\link{gensvm}}.
#'
#' @export
#'
#' @examples
-#' X <-
-#'
-
-gensvm.grid <- function(X, y,
- df=NULL,
- p=c(1.0, 1.5, 2.0),
- lambda=c(1e-8, 1e-6, 1e-4, 1e-2, 1),
- kappa=c(-0.9, 0.5, 5.0),
- weights=c('unit', 'group'),
- kernel=c('linear'),
- gamma=c('auto'),
- coef=c(0.0),
- degree=c(2.0),
- max.iter=c(1e8),
- refit=TRUE,
- scoring=NULL,
- cv=3,
- verbose=TRUE)
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # use the default parameter grid
+#' grid <- gensvm.grid(x, y)
+#'
+#' # use a smaller parameter grid
+#' pg <- expand.grid(p=c(1.0, 1.5, 2.0), kappa=c(-0.9, 1.0), epsilon=c(1e-3))
+#' grid <- gensvm.grid(x, y, param.grid=pg)
+#'
+#' # print the result
+#' print(grid)
+#'
+#' # Using a custom scoring function (accuracy as percentage)
+#' acc.pct <- function(yt, yp) { return (100 * sum(yt == yp) / length(yt)) }
+#' grid <- gensvm.grid(x, y, scoring=acc.pct)
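+#'
+#' # Using one of the predefined grids and a custom CV split vector (a
+#' # sketch; see the param.grid and cv parameter descriptions above)
+#' grid <- gensvm.grid(x, y, param.grid='small')
+#' cv.folds <- sample(rep(1:5, length.out=nrow(x)))
+#' grid <- gensvm.grid(x, y, cv=cv.folds)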
+#'
+gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3,
+ verbose=0, return.train.score=TRUE)
{
call <- match.call()
+ n.objects <- nrow(X)
+ n.features <- ncol(X)
+ n.classes <- length(unique(y))
+
+ if (is.character(param.grid)) {
+ if (param.grid == 'tiny') {
+ param.grid <- gensvm.load.tiny.grid()
+ } else if (param.grid == 'small') {
+ param.grid <- gensvm.load.small.grid()
+ } else if (param.grid == 'full') {
+ param.grid <- gensvm.load.full.grid()
+ }
+ }
+
+ # Validate the range of the values for the gridsearch
+ gensvm.validate.param.grid(param.grid)
+
+ # Sort the parameter grid for efficient warm starts
+ param.grid <- gensvm.sort.param.grid(param.grid)
+
+ # Expand and convert the parameter grid for use in the C function
+ C.param.grid <- gensvm.expand.param.grid(param.grid, n.features)
+
+ # Convert labels to integers
+ classes <- sort(unique(y))
+ y.clean <- match(y, classes)
+
+ if (is.vector(cv) && length(cv) == n.objects) {
+ folds <- sort(unique(cv))
+ cv.idx <- match(cv, folds) - 1
+ n.splits <- length(folds)
+ } else {
+ cv.idx <- gensvm.generate.cv.idx(n.objects, cv[1])
+ n.splits <- cv
+ }
+
+ results <- .Call("R_gensvm_grid",
+ as.matrix(X),
+ as.integer(y.clean),
+ as.matrix(C.param.grid),
+ as.integer(nrow(C.param.grid)),
+ as.integer(ncol(C.param.grid)),
+ as.integer(cv.idx),
+ as.integer(n.splits),
+ as.integer(verbose),
+ as.integer(n.objects),
+ as.integer(n.features),
+ as.integer(n.classes)
+ )
+
+ cv.results <- gensvm.cv.results(results, param.grid, cv.idx,
+ y.clean, scoring,
+ return.train.score=return.train.score)
+ best.index <- which.min(cv.results$rank.test.score)[1]
+ if (!is.na(best.index)) { # can occur when user interrupts
+ best.score <- cv.results$mean.test.score[best.index]
+ best.params <- param.grid[best.index, , drop=F]
+ # Remove duplicate attributes from best.params
+ attr(best.params, "out.attrs") <- NULL
+ } else {
+ best.score <- NA
+ best.params <- list()
+ }
- object <- list(...)
+ if (refit && !is.na(best.index)) {
+ gensvm.args <- as.list(best.params)
+ gensvm.args$X <- X
+ gensvm.args$y <- y
+ best.estimator <- do.call(gensvm, gensvm.args)
+ } else {
+ best.estimator <- NULL
+ }
+
+ object <- list(call = call, param.grid = param.grid,
+ cv.results = cv.results, best.estimator = best.estimator,
+ best.score = best.score, best.params = best.params,
+ best.index = best.index, n.splits = n.splits,
+ n.objects = n.objects, n.features = n.features,
+ n.classes = n.classes, classes = classes,
+ total.time = results$total.time, cv.idx = cv.idx)
class(object) <- "gensvm.grid"
return(object)
}
+
+#' @title Load a tiny parameter grid for the GenSVM grid search
+#'
+#' @description This function returns a parameter grid to use in the GenSVM
+#' grid search. This grid was obtained by analyzing the experiments done for
+#' the GenSVM paper and selecting the configurations that achieve accuracy
+#' within the 95th percentile on over 90% of the datasets. It is a good start
+#' for a parameter search with a reasonably high chance of achieving good
+#' performance on most datasets.
+#'
+#' Note that this grid is only tested to work well in combination with the
+#' linear kernel.
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @export
+#'
+#' @seealso
+#' \code{\link{gensvm.grid}}, \code{\link{gensvm.load.small.grid}},
+#' \code{\link{gensvm.load.full.grid}}.
+#'
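+#' @examples
+#' # A small illustration (not in the original docs): load the grid and
+#' # inspect its size before passing it to gensvm.grid
+#' pg <- gensvm.load.tiny.grid()
+#' nrow(pg)  # 10 configurations
+#'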
+gensvm.load.tiny.grid <- function()
+{
+ df <- data.frame(
+ p=c(2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 2.0, 2.0),
+ kappa=c(5.0, 5.0, 0.5, 5.0, -0.9, 5.0, 0.5, -0.9, 0.5, 0.5),
+ lambda = c(2^-16, 2^-18, 2^-18, 2^-18, 2^-18, 2^-14, 2^-18,
+ 2^-18, 2^-16, 2^-16),
+ weight = c('unit', 'unit', 'unit', 'group', 'unit',
+ 'unit', 'group', 'unit', 'unit', 'group')
+ )
+ return(df)
+}
+
+#' @title Load a large parameter grid for the GenSVM grid search
+#'
+#' @description This loads the parameter grid from the GenSVM paper. It
+#' consists of 342 configurations and is constructed from all possible
+#' combinations of the following parameter sets:
+#'
+#' \code{p = c(1.0, 1.5, 2.0)}
+#'
+#' \code{lambda = 2^seq(-18, 18, 2)}
+#'
+#' \code{kappa = c(-0.9, 0.5, 5.0)}
+#'
+#' \code{weight = c('unit', 'group')}
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @export
+#'
+#' @seealso
+#' \code{\link{gensvm.grid}}, \code{\link{gensvm.load.tiny.grid}},
+#' \code{\link{gensvm.load.full.grid}}.
+#'
+gensvm.load.full.grid <- function()
+{
+ df <- expand.grid(p=c(1.0, 1.5, 2.0), lambda=2^seq(-18, 18, 2),
+ kappa=c(-0.9, 0.5, 5.0), weight=c('unit', 'group'),
+ epsilon=c(1e-6))
+ return(df)
+}
+
+
+#' @title Load the default parameter grid for the GenSVM grid search
+#'
+#' @description This function loads a default parameter grid to use for the
+#' GenSVM gridsearch. It contains all possible combinations of the following
+#' parameter sets:
+#'
+#' \code{p = c(1.0, 1.5, 2.0)}
+#'
+#' \code{lambda = c(1e-8, 1e-6, 1e-4, 1e-2, 1)}
+#'
+#' \code{kappa = c(-0.9, 0.5, 5.0)}
+#'
+#' \code{weight = c('unit', 'group')}
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @export
+#'
+#' @seealso
+#' \code{\link{gensvm.grid}}, \code{\link{gensvm.load.tiny.grid}},
+#' \code{\link{gensvm.load.small.grid}}.
+#'
+gensvm.load.small.grid <- function()
+{
+ df <- expand.grid(p=c(1.0, 1.5, 2.0), lambda=c(1e-8, 1e-6, 1e-4, 1e-2, 1),
+ kappa=c(-0.9, 0.5, 5.0), weight=c('unit', 'group'))
+ return(df)
+}
+
+
+#' Generate a vector of cross-validation indices
+#'
+#' This function generates a vector of length \code{n} with values from 0 to
+#' \code{folds-1} to mark train and test splits.
+#'
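+#' A quick usage sketch (this helper is not exported by the package):
+#' \preformatted{
+#' idx <- gensvm.generate.cv.idx(10, 3)
+#' table(idx)  # roughly equal counts for folds 0, 1, 2
+#' }
+#'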
+gensvm.generate.cv.idx <- function(n, folds)
+{
+ cv.idx <- matrix(0, n, 1)
+
+ big.folds <- n %% folds
+ small.fold.size <- n %/% folds
+
+ j <- 0
+ for (i in 0:(small.fold.size * folds)) {
+ while (TRUE) {
+ idx <- round(runif(1, 1, n))
+ if (cv.idx[idx] == 0) {
+ cv.idx[idx] <- j
+ j <- j + 1
+ j <- (j %% folds)
+ break
+ }
+ }
+ }
+
+ j <- 1
+ i <- 0
+ while (i < big.folds) {
+ if (cv.idx[j] == 0) {
+ cv.idx[j] <- i
+ i <- i + 1
+ }
+ j <- j + 1
+ }
+
+ return(cv.idx)
+}
+
+gensvm.validate.param.grid <- function(df)
+{
+ expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight",
+ "kappa", "lambda", "p", "epsilon", "max.iter")
+ for (name in colnames(df)) {
+ if (!(name %in% expected.colnames)) {
+ stop("Invalid header name supplied in parameter grid: ", name)
+ }
+ }
+
+ conditions <- list(
+ p=function(x) { x >= 1.0 && x <= 2.0 },
+ kappa=function(x) { x > -1.0 },
+ lambda=function(x) {x > 0.0 },
+ epsilon=function(x) { x > 0.0 },
+ gamma=function(x) { x != 0.0 },
+ weight=function(x) { x %in% c("unit", "group") },
+ kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") }
+ )
+
+ for (idx in 1:nrow(df)) {
+ for (param in colnames(df)) {
+ if (!(param %in% names(conditions)))
+ next
+ func <- conditions[[param]]
+ value <- df[[param]][idx]
+ if (!func(value))
+ stop("Invalid value in grid for parameter: ", param)
+ }
+ }
+}
+
+gensvm.cv.results <- function(results, param.grid, cv.idx, y.true,
+ scoring, return.train.score=TRUE)
+{
+ n.candidates <- nrow(param.grid)
+ n.splits <- length(unique(cv.idx))
+
+ score <- if(is.function(scoring)) scoring else gensvm.accuracy
+
+ # Build names and initialize the data.frame
+ names <- c("mean.fit.time", "mean.score.time", "mean.test.score")
+ if (return.train.score)
+ names <- c(names, "mean.train.score")
+ for (param in names(param.grid)) {
+ names <- c(names, sprintf("param.%s", param))
+ }
+ names <- c(names, "rank.test.score")
+ for (idx in sort(unique(cv.idx))) {
+ names <- c(names, sprintf("split%i.test.score", idx))
+ if (return.train.score)
+ names <- c(names, sprintf("split%i.train.score", idx))
+ }
+ names <- c(names, "std.fit.time", "std.score.time", "std.test.score")
+ if (return.train.score)
+ names <- c(names, "std.train.score")
+
+ df <- data.frame(matrix(ncol=length(names), nrow=n.candidates))
+ colnames(df) <- names
+
+ for (pidx in 1:n.candidates) {
+ param <- param.grid[pidx, , drop=F]
+ durations <- results$durations[pidx, ]
+ predictions <- results$predictions[pidx, ]
+
+ fit.times <- durations
+ score.times <- c()
+ test.scores <- c()
+ train.scores <- c()
+
+ is.missing <- any(is.na(durations))
+
+ for (test.idx in sort(unique(cv.idx))) {
+ score.time <- 0
+
+ if (return.train.score) {
+ y.train.pred <- predictions[cv.idx != test.idx]
+ y.train.true <- y.true[cv.idx != test.idx]
+
+ start.time <- proc.time()
+ train.score <- score(y.train.true, y.train.pred)
+ stop.time <- proc.time()
+ score.time <- score.time + (stop.time - start.time)[3]
+
+ train.scores <- c(train.scores, train.score)
+ }
+
+ y.test.pred <- predictions[cv.idx == test.idx]
+ y.test.true <- y.true[cv.idx == test.idx]
+
+ start.time <- proc.time()
+ test.score <- score(y.test.true, y.test.pred)
+ stop.time <- proc.time()
+ score.time <- score.time + (stop.time - start.time)[3]
+
+ test.scores <- c(test.scores, test.score)
+
+ score.times <- c(score.times, score.time)
+ }
+
+ df$mean.fit.time[pidx] <- mean(fit.times)
+ df$mean.score.time[pidx] <- if(is.missing) NA else mean(score.times)
+ df$mean.test.score[pidx] <- mean(test.scores)
+ df$std.fit.time[pidx] <- sd(fit.times)
+ df$std.score.time[pidx] <- if(is.missing) NA else sd(score.times)
+ df$std.test.score[pidx] <- sd(test.scores)
+ if (return.train.score) {
+ df$mean.train.score[pidx] <- mean(train.scores)
+ df$std.train.score[pidx] <- sd(train.scores)
+ }
+
+ for (parname in names(param.grid)) {
+ df[[sprintf("param.%s", parname)]][pidx] <- param[[parname]]
+ }
+
+ j <- 1
+ for (test.idx in sort(unique(cv.idx))) {
+ lbl <- sprintf("split%i.test.score", test.idx)
+ df[[lbl]][pidx] <- test.scores[j]
+ if (return.train.score) {
+ lbl <- sprintf("split%i.train.score", test.idx)
+ df[[lbl]][pidx] <- train.scores[j]
+ }
+ j <- j + 1
+ }
+ }
+
+ df$rank.test.score <- gensvm.rank.score(df$mean.test.score)
+
+ return(df)
+}
+
+gensvm.sort.param.grid <- function(param.grid)
+{
+ all.cols <- c("kernel", "coef", "degree", "gamma", "weight", "kappa",
+ "lambda", "p", "epsilon", "max.iter")
+
+ order.args <- NULL
+ for (name in all.cols) {
+ if (name %in% colnames(param.grid)) {
+ if (name == "epsilon")
+ order.args <- cbind(order.args, -param.grid[[name]])
+ else
+ order.args <- cbind(order.args, param.grid[[name]])
+ }
+ }
+ sorted.pg <- param.grid[do.call(order, as.list(as.data.frame(order.args))), ]
+
+ rownames(sorted.pg) <- NULL
+
+ return(sorted.pg)
+}
+
+gensvm.expand.param.grid <- function(pg, n.features)
+{
+ if ("kernel" %in% colnames(pg)) {
+ all.kernels <- c("linear", "poly", "rbf", "sigmoid")
+ pg$kernel <- match(pg$kernel, all.kernels) - 1
+ } else {
+ pg$kernel <- 0
+ }
+
+ if ("weight" %in% colnames(pg)) {
+ all.weights <- c("unit", "group")
+ pg$weight <- match(pg$weight, all.weights)
+ } else {
+ pg$weight <- 1
+ }
+
+ if ("gamma" %in% colnames(pg)) {
+ pg$gamma[pg$gamma == "auto"] <- 1.0/n.features
+ } else {
+ pg$gamma <- 1.0/n.features
+ }
+
+ if (!("degree" %in% colnames(pg)))
+ pg$degree <- 2.0
+ if (!("coef" %in% colnames(pg)))
+ pg$coef <- 0.0
+ if (!("p" %in% colnames(pg)))
+ pg$p <- 1.0
+ if (!("lambda" %in% colnames(pg)))
+ pg$lambda <- 1e-8
+ if (!("kappa" %in% colnames(pg)))
+ pg$kappa <- 0.0
+ if (!("epsilon" %in% colnames(pg)))
+ pg$epsilon <- 1e-6
+ if (!("max.iter" %in% colnames(pg)))
+ pg$max.iter <- 1e8
+
+ C.param.grid <- data.frame(kernel=pg$kernel, coef=pg$coef,
+ degree=pg$degree, gamma=pg$gamma,
+ weight=pg$weight, kappa=pg$kappa,
+ lambda=pg$lambda, p=pg$p, epsilon=pg$epsilon,
+ max.iter=pg$max.iter)
+
+ return(C.param.grid)
+}
+
+#' @title Compute the ranks for the numbers in a given vector
+#'
+#' @details
+#' This function computes the ranks for the values in an array. The highest
+#' value gets the smallest rank. Ties are broken by assigning the same
+#' (smallest available) rank to all tied values.
+#'
+#' @param x array of numeric values
+#'
+#' @examples
+#' x <- c(7, 0.1, 0.5, 0.1, 10, 100, 200)
+#' gensvm.rank.score(x)
+#' # [1] 4 6 5 6 3 2 1
+#'
+gensvm.rank.score <- function(x)
+{
+ x <- as.array(x)
+ l <- length(x)
+ r <- 1
+ ranks <- as.vector(matrix(0, l, 1))
+ ranks[which(is.na(x))] <- NA
+    while (!all(is.na(x))) {
+ m <- max(x, na.rm=T)
+ idx <- which(x == m)
+ ranks[idx] <- r
+ r <- r + length(idx)
+ x[idx] <- NA
+ }
+
+ return(ranks)
+}
diff --git a/R/gensvm.maxabs.scale.R b/R/gensvm.maxabs.scale.R
new file mode 100644
index 0000000..6ac351b
--- /dev/null
+++ b/R/gensvm.maxabs.scale.R
@@ -0,0 +1,77 @@
+#' @title Scale each column of a matrix by its maximum absolute value
+#'
+#' @description Scaling a dataset can greatly decrease the computation time of
+#' GenSVM. This function scales the data by dividing each column of a matrix by
+#' the maximum absolute value of that column. This preserves sparsity in the
+#' data while mapping each column to the interval [-1, 1].
+#'
+#' Optionally a test dataset can be provided as well. In this case, the scaling
+#' will be computed on the first argument (\code{x}) and applied to the test
+#' dataset. Note that the return value is a list when this argument is
+#' supplied.
+#'
+#' @param x a matrix to scale
+#' @param x.test (optional) a test matrix to scale as well.
+#'
+#' @return If x.test is NULL, a scaled matrix in which the maximum value of
+#' each column is at most 1 and the minimum value of each column is at least
+#' -1. If x.test is supplied, a list with elements \code{x} and \code{x.test}
+#' representing the scaled datasets.
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#'
+#' # check the min and max of the columns
+#' apply(x, 2, min)
+#' apply(x, 2, max)
+#'
+#' # scale the data
+#' x.scale <- gensvm.maxabs.scale(x)
+#'
+#' # check again (max should be 1.0, min shouldn't be below -1)
+#' apply(x.scale, 2, min)
+#' apply(x.scale, 2, max)
+#'
+#' # with a train and test dataset
+#' x <- iris[, -5]
+#' split <- gensvm.train.test.split(x)
+#' x.train <- split$x.train
+#' x.test <- split$x.test
+#' scaled <- gensvm.maxabs.scale(x.train, x.test)
+#' x.train.scl <- scaled$x
+#' x.test.scl <- scaled$x.test
+#'
+gensvm.maxabs.scale <- function(x, x.test=NULL)
+{
+ xm <- as.matrix(x)
+ max.abs <- apply(apply(xm, 2, abs), 2, max)
+ max.abs[max.abs == 0] <- 1
+
+ scaled <- xm %*% diag(1.0 / max.abs)
+ colnames(scaled) <- colnames(x)
+ rownames(scaled) <- rownames(x)
+
+ if (!is.null(x.test)) {
+ xtm <- as.matrix(x.test)
+ scaled.test <- xtm %*% diag(1.0 / max.abs)
+ colnames(scaled.test) <- colnames(x.test)
+ rownames(scaled.test) <- rownames(x.test)
+
+ ret.val <- list(x=scaled, x.test=scaled.test)
+ } else {
+ ret.val <- scaled
+ }
+
+ return(ret.val)
+}
diff --git a/R/gensvm.refit.R b/R/gensvm.refit.R
new file mode 100644
index 0000000..a6af3fd
--- /dev/null
+++ b/R/gensvm.refit.R
@@ -0,0 +1,82 @@
+#' @title Train an already fitted model on new data
+#'
+#' @description This function can be used to train an existing model on new data or
+#' fit an existing model with slightly different parameters. It is useful for
+#' retraining without having to copy all the parameters over. One common
+#' application for this is to refit the best model found by a grid search, as
+#' illustrated in the examples.
+#'
+#' @param fit Fitted \code{gensvm} object
+#' @param X Data matrix of the new data
+#' @param y Label vector of the new data
+#' @param p,lambda,kappa,epsilon,weights,kernel,gamma,coef,degree,kernel.eigen.cutoff,max.iter,verbose,random.seed
+#' Hyperparameters for the new fit. If NULL (the default), the value from the
+#' fitted model is used. See \code{\link{gensvm}} for their meaning.
+#'
+#' @return a new fitted \code{gensvm} model
+#'
+#' @export
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # fit a standard model and refit with slightly different parameters
+#' fit <- gensvm(x, y)
+#' fit2 <- gensvm.refit(fit, x, y, epsilon=1e-8)
+#'
+#' # refit a model returned by a grid search
+#' grid <- gensvm.grid(x, y)
+#' fit <- gensvm.refit(grid$best.estimator, x, y, epsilon=1e-8)
+#'
+#' # refit on different data
+#' idx <- runif(nrow(x)) > 0.5
+#' x1 <- x[idx, ]
+#' x2 <- x[!idx, ]
+#' y1 <- y[idx]
+#' y2 <- y[!idx]
+#'
+#' fit1 <- gensvm(x1, y1)
+#' fit2 <- gensvm.refit(fit1, x2, y2)
+#'
+gensvm.refit <- function(fit, X, y, p=NULL, lambda=NULL, kappa=NULL,
+ epsilon=NULL, weights=NULL, kernel=NULL, gamma=NULL,
+ coef=NULL, degree=NULL, kernel.eigen.cutoff=NULL,
+ max.iter=NULL, verbose=NULL, random.seed=NULL)
+{
+ p <- if(is.null(p)) fit$p else p
+ lambda <- if(is.null(lambda)) fit$lambda else lambda
+ kappa <- if(is.null(kappa)) fit$kappa else kappa
+ epsilon <- if(is.null(epsilon)) fit$epsilon else epsilon
+ weights <- if(is.null(weights)) fit$weights else weights
+ kernel <- if(is.null(kernel)) fit$kernel else kernel
+ gamma <- if(is.null(gamma)) fit$gamma else gamma
+ coef <- if(is.null(coef)) fit$coef else coef
+ degree <- if(is.null(degree)) fit$degree else degree
+ kernel.eigen.cutoff <- (if(is.null(kernel.eigen.cutoff))
+ fit$kernel.eigen.cutoff else kernel.eigen.cutoff)
+ max.iter <- if(is.null(max.iter)) fit$max.iter else max.iter
+ verbose <- if(is.null(verbose)) fit$verbose else verbose
+ random.seed <- if(is.null(random.seed)) fit$random.seed else random.seed
+
+ # Setting the error handler here is necessary in case the user interrupts
+ # this call to gensvm. If we don't set the error handler, R will
+ # unnecessarily drop to a browser() session. We reset the error handler
+ # after the call to gensvm().
+ options(error=function() {})
+ newfit <- gensvm(X, y, p=p, lambda=lambda, kappa=kappa, epsilon=epsilon,
+ weights=weights, kernel=kernel, gamma=gamma, coef=coef,
+ degree=degree, kernel.eigen.cutoff=kernel.eigen.cutoff,
+ verbose=verbose, max.iter=max.iter, seed.V=coef(fit))
+ options(error=NULL)
+
+ return(newfit)
+}
diff --git a/R/gensvm.train.test.split.R b/R/gensvm.train.test.split.R
new file mode 100644
index 0000000..406f80e
--- /dev/null
+++ b/R/gensvm.train.test.split.R
@@ -0,0 +1,121 @@
+#' @title Create a train/test split of a dataset
+#'
+#' @description Often it is desirable to split a dataset into a training and
+#' testing sample. This function is included in GenSVM to make it easy to do
+#' so. The function is inspired by a similar function in Scikit-Learn.
+#'
+#' @param x array to split
+#' @param y another array to split (typically this is a vector)
+#' @param train.size size of the training dataset. This can be provided as
+#' float or as int. If it's a float, it should be between 0.0 and 1.0 and
+#' represents the fraction of the dataset that should be placed in the training
+#' dataset. If it's an int, it represents the exact number of samples in the
+#' training dataset. If it is NULL, the complement of \code{test.size} will be
+#' used.
+#' @param test.size size of the test dataset. Similarly to train.size both a
+#' float or an int can be supplied. If it's NULL, the complement of train.size
+#' will be used. If both train.size and test.size are NULL, a default test.size
+#' of 0.25 will be used.
+#' @param shuffle shuffle the rows or not
+#' @param random.state seed for the random number generator (int)
+#' @param return.idx whether to also return the train and test indices as
+#' \code{idx.train} and \code{idx.test}
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # using the default values
+#' split <- gensvm.train.test.split(x, y)
+#'
+#' # using the split in a GenSVM model
+#' fit <- gensvm(split$x.train, split$y.train)
+#' gensvm.accuracy(split$y.test, predict(fit, split$x.test))
+#'
+#' # using attach makes the results directly available
+#' attach(gensvm.train.test.split(x, y))
+#' fit <- gensvm(x.train, y.train)
+#' gensvm.accuracy(y.test, predict(fit, x.test))
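+#'
+#' # sizes can be given as a fraction or as an exact count (see the
+#' # train.size and test.size descriptions above)
+#' split <- gensvm.train.test.split(x, y, train.size=0.8)
+#' split <- gensvm.train.test.split(x, y, test.size=30, random.state=42)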
+#'
+gensvm.train.test.split <- function(x, y=NULL, train.size=NULL, test.size=NULL,
+ shuffle=TRUE, random.state=NULL,
+ return.idx=FALSE)
+{
+ if (!is.null(y) && dim(as.matrix(x))[1] != dim(as.matrix(y))[1]) {
+ cat("Error: First dimension of x and y should be equal.\n")
+        return(NULL)
+ }
+
+ n.objects <- dim(as.matrix(x))[1]
+
+ if (is.null(train.size) && is.null(test.size)) {
+ test.size <- round(0.25 * n.objects)
+ train.size <- n.objects - test.size
+ }
+ else if (is.null(train.size)) {
+ if (test.size > 0.0 && test.size < 1.0)
+ test.size <- round(n.objects * test.size)
+ train.size <- n.objects - test.size
+ }
+ else if (is.null(test.size)) {
+ if (train.size > 0.0 && train.size < 1.0)
+ train.size <- round(n.objects * train.size)
+ test.size <- n.objects - train.size
+ }
+ else {
+ if (train.size > 0.0 && train.size < 1.0)
+ train.size <- round(n.objects * train.size)
+ if (test.size > 0.0 && test.size < 1.0)
+ test.size <- round(n.objects * test.size)
+ }
+
+ if (!is.null(random.state))
+ set.seed(random.state)
+
+ if (shuffle) {
+ train.idx <- sample(n.objects, train.size)
+ diff <- setdiff(1:n.objects, train.idx)
+ test.idx <- sample(diff, test.size)
+ } else {
+ train.idx <- 1:train.size
+ diff <- setdiff(1:n.objects, train.idx)
+ test.idx <- diff[1:test.size]
+ }
+
+ x.train <- x[train.idx, ]
+ x.test <- x[test.idx, ]
+
+ if (!is.null(y)) {
+ if (is.matrix(y)) {
+ y.train <- y[train.idx, ]
+ y.test <- y[test.idx, ]
+ } else {
+ y.train <- y[train.idx]
+ y.test <- y[test.idx]
+ }
+ }
+
+ out <- list(
+ x.train = x.train,
+ x.test = x.test
+ )
+ if (!is.null(y)) {
+ out$y.train <- y.train
+ out$y.test <- y.test
+ }
+ if (return.idx) {
+ out$idx.train <- train.idx
+ out$idx.test <- test.idx
+ }
+
+ return(out)
+}
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
new file mode 100644
index 0000000..0ce215b
--- /dev/null
+++ b/R/plot.gensvm.R
@@ -0,0 +1,199 @@
+#' @title Plot the simplex space of the fitted GenSVM model
+#'
+#' @description This function creates a plot of the simplex space for a fitted
+#' GenSVM model and the given data set, as long as the dataset consists of only
+#' 3 classes. For more than 3 classes, the simplex space is too high
+#' dimensional to easily visualize.
+#'
+#' @param fit A fitted \code{gensvm} object
+#' @param x the dataset to plot
+#' @param y.true the true data labels. If provided the objects will be colored
+#' using the true labels instead of the predicted labels. This makes it easy to
+#' identify misclassified objects.
+#' @param with.margins plot the margins
+#' @param with.shading show shaded areas for the class regions
+#' @param with.legend show the legend for the class labels
+#' @param center.plot ensure that the boundaries and margins are always visible
+#' in the plot
+#' @param ... further arguments are ignored
+#'
+#' @return returns the object passed as input
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @method plot gensvm
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # train the model
+#' fit <- gensvm(x, y)
+#'
+#' # plot the simplex space
+#' plot(fit, x)
+#'
+#' # plot and use the true colors (easier to spot misclassified samples)
+#' plot(fit, x, y.true=y)
+#'
+#' # plot only misclassified samples
+#' x.mis <- x[predict(fit, x) != y, ]
+#' y.mis.true <- y[predict(fit, x) != y]
+#' plot(fit, x.mis)
+#' plot(fit, x.mis, y.true=y.mis.true)
+#'
+plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
+ with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
+ ...)
+{
+ if (fit$n.classes != 3) {
+ cat("Error: Can only plot with 3 classes\n")
+        return(NULL)
+ }
+
+ # Sanity check
+ if (ncol(x) != fit$n.features) {
+ cat("Error: Number of features of fitted model and testing data
+ disagree.\n")
+ return
+ }
+
+ x.train <- fit$X.train
+ if (fit$kernel != 'linear' && is.null(x.train)) {
+ cat("Error: The training data is needed to plot data for ",
+ "nonlinear GenSVM. This data is not present in the fitted ",
+ "model!\n", sep="")
+        return(NULL)
+ }
+ if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
+ cat("Error: Number of features of fitted model and training data disagree.")
+ return
+ }
+
+ x <- as.matrix(x)
+
+ if (fit$kernel == 'linear') {
+ V <- coef(fit)
+ Z <- cbind(matrix(1, dim(x)[1], 1), x)
+ S <- Z %*% V
+ y.pred.orig <- predict(fit, x)
+ } else {
+ kernels <- c("linear", "poly", "rbf", "sigmoid")
+ kernel.idx <- which(kernels == fit$kernel) - 1
+ plotdata <- .Call("R_gensvm_plotdata_kernels",
+ as.matrix(x),
+ as.matrix(x.train),
+ as.matrix(fit$V),
+ as.integer(nrow(fit$V)),
+ as.integer(ncol(fit$V)),
+ as.integer(nrow(x.train)),
+ as.integer(nrow(x)),
+ as.integer(fit$n.features),
+ as.integer(fit$n.classes),
+ as.integer(kernel.idx),
+ fit$gamma,
+ fit$coef,
+ fit$degree,
+ fit$kernel.eigen.cutoff
+ )
+ S <- plotdata$ZV
+ y.pred.orig <- plotdata$y.pred
+ }
+
+ classes <- fit$classes
+ if (is.factor(y.pred.orig)) {
+ y.pred <- match(y.pred.orig, classes)
+ } else {
+ y.pred <- y.pred.orig
+ }
+
+ # Define some colors
+ point.blue <- rgb(31, 119, 180, maxColorValue=255)
+ point.orange <- rgb(255, 127, 14, maxColorValue=255)
+ point.green <- rgb(44, 160, 44, maxColorValue=255)
+ fill.blue <- rgb(31, 119, 180, 51, maxColorValue=255)
+ fill.orange <- rgb(255, 127, 14, 51, maxColorValue=255)
+ fill.green <- rgb(44, 160, 44, 51, maxColorValue=255)
+
+ colors <- as.matrix(c(point.green, point.blue, point.orange))
+ markers <- as.matrix(c(15, 16, 17))
+
+ if (is.null(y.true)) {
+ col.vector <- colors[y.pred]
+ mark.vector <- markers[y.pred]
+ } else {
+ col.vector <- colors[y.true]
+ mark.vector <- markers[y.true]
+ }
+
+ par(pty="s")
+ if (center.plot) {
+ new.xlim <- c(min(min(S[, 1]), -1.2), max(max(S[, 1]), 1.2))
+ new.ylim <- c(min(min(S[, 2]), -0.75), max(max(S[, 2]), 1.2))
+ plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='',
+ asp=1, xlim=new.xlim, ylim=new.ylim)
+ } else {
+ plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='',
+ asp=1)
+ }
+
+ limits <- par("usr")
+ xmin <- limits[1]
+ xmax <- limits[2]
+ ymin <- limits[3]
+ ymax <- limits[4]
+
+ # draw the fixed boundaries
+ segments(0, 0, 0, ymin)
+ segments(0, 0, xmax, xmax/sqrt(3))
+ segments(xmin, abs(xmin)/sqrt(3), 0, 0)
+
+ if (with.margins) {
+ # margin from left below decision boundary to center
+ segments(xmin, -xmin/sqrt(3) - sqrt(4/3), -1, -1/sqrt(3), lty=2)
+
+ # margin from left center to down
+ segments(-1, -1/sqrt(3), -1, ymin, lty=2)
+
+ # margin from right center to middle
+ segments(1, -1/sqrt(3), 1, ymin, lty=2)
+
+ # margin from right center to right boundary
+ segments(1, -1/sqrt(3), xmax, xmax/sqrt(3) - sqrt(4/3), lty=2)
+
+ # margin from center to top left
+ segments(xmin, -xmin/sqrt(3) + sqrt(4/3), 0, sqrt(4/3), lty=2)
+
+ # margin from center to top right
+ segments(0, sqrt(4/3), xmax, xmax/sqrt(3) + sqrt(4/3), lty=2)
+ }
+
+ if (with.shading) {
+ # bottom left
+ polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, -1/sqrt(3), -xmin/sqrt(3) -
+ sqrt(4/3)), col=fill.green, border=NA)
+ # bottom right
+ polygon(c(1, xmax, xmax, 1), c(ymin, ymin, xmax/sqrt(3) - sqrt(4/3),
+ -1/sqrt(3)), col=fill.blue, border=NA)
+ # top
+ polygon(c(xmin, 0, xmax, xmax, xmin),
+ c(-xmin/sqrt(3) + sqrt(4/3), sqrt(4/3), xmax/sqrt(3) + sqrt(4/3),
+ ymax, ymax), col=fill.orange,
+ border=NA)
+ }
+
+ if (with.legend) {
+ offset <- abs(xmax - xmin) * 0.05
+ legend(xmax + offset, ymax, classes, col=colors, pch=markers, xpd=T)
+ }
+
+ invisible(fit)
+}
diff --git a/R/plot.gensvm.grid.R b/R/plot.gensvm.grid.R
new file mode 100644
index 0000000..da101e6
--- /dev/null
+++ b/R/plot.gensvm.grid.R
@@ -0,0 +1,39 @@
+#' @title Plot the simplex space of the best fitted model in the GenSVMGrid
+#'
+#' @description This is a wrapper which calls the plot function for the best
+#' model in the provided GenSVMGrid object. See the documentation for
+#' \code{\link{plot.gensvm}} for more information.
+#'
+#' @param grid A \code{gensvm.grid} object trained with refit=TRUE
+#' @param x the dataset to plot
+#' @param ... further arguments are passed to the plot function
+#'
+#' @return returns the object passed as input
+#'
+#' @export
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' grid <- gensvm.grid(x, y)
+#' plot(grid, x)
+#'
+plot.gensvm.grid <- function(grid, x, ...)
+{
+ if (is.null(grid$best.estimator)) {
+ cat("Error: Can't plot, the best.estimator element is NULL\n")
+ return
+ }
+ fit <- grid$best.estimator
+ return(plot(fit, x, ...))
+}
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 5c8f2e7..7e04fe4 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -4,8 +4,8 @@
#' fitted GenSVM model.
#'
#' @param fit Fitted \code{gensvm} object
-#' @param newx Matrix of new values for \code{x} for which predictions need to
-#' be made.
+#' @param x.test Matrix of new values for \code{x} for which predictions need
+#' to be made.
#' @param \dots further arguments are ignored
#'
#' @return a vector of class labels, with the same type as the original class
@@ -15,7 +15,7 @@
#' @aliases predict
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -24,10 +24,20 @@
#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
#'
#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
#'
+#' # create a training and test sample
+#' attach(gensvm.train.test.split(x, y))
+#' fit <- gensvm(x.train, y.train)
#'
+#' # predict the class labels of the test sample
+#' y.test.pred <- predict(fit, x.test)
#'
-predict.gensvm <- function(fit, newx, ...)
+#' # compute the accuracy with gensvm.accuracy
+#' gensvm.accuracy(y.test, y.test.pred)
+#'
+predict.gensvm <- function(fit, x.test, ...)
{
## Implementation note:
## - It might seem that it would be faster to do the prediction directly in
@@ -37,16 +47,53 @@ predict.gensvm <- function(fit, newx, ...)
## the C implementation is *much* faster than doing it in R.
# Sanity check
- if (ncol(newx) != fit$n.features)
- stop("Number of features of fitted model and supplied data disagree.")
+ if (ncol(x.test) != fit$n.features) {
+ cat("Error: Number of features of fitted model and testing",
+ "data disagree.\n")
+ return(invisible(NULL))
+ }
+
+ x.train <- fit$X.train
+ if (fit$kernel != 'linear' && is.null(x.train)) {
+ cat("Error: The training data is needed to compute predictions for ",
+ "nonlinear GenSVM. This data is not present in the fitted ",
+ "model!\n", sep="")
+ }
+ if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
+ cat("Error: Number of features of fitted model and training",
+ "data disagree.\n")
+ return(invisible(NULL))
+ }
- y.pred.c <- .Call("R_gensvm_predict",
- as.matrix(newx),
+ if (fit$kernel == 'linear') {
+ y.pred.c <- .Call("R_gensvm_predict",
+ as.matrix(x.test),
as.matrix(fit$V),
- as.integer(nrow(newx)),
- as.integer(ncol(newx)),
+ as.integer(nrow(x.test)),
+ as.integer(ncol(x.test)),
as.integer(fit$n.classes)
)
+ } else {
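+ # translate the kernel name to the zero-based index used by the C
+ # library (this ordering is assumed to match the kernel enum in the
+ # GenSVM C code)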
+ kernels <- c("linear", "poly", "rbf", "sigmoid")
+ kernel.idx <- which(kernels == fit$kernel) - 1
+ y.pred.c <- .Call("R_gensvm_predict_kernels",
+ as.matrix(x.test),
+ as.matrix(x.train),
+ as.matrix(fit$V),
+ as.integer(nrow(fit$V)),
+ as.integer(ncol(fit$V)),
+ as.integer(nrow(x.train)),
+ as.integer(nrow(x.test)),
+ as.integer(fit$n.features),
+ as.integer(fit$n.classes),
+ as.integer(kernel.idx),
+ fit$gamma,
+ fit$coef,
+ fit$degree,
+ fit$kernel.eigen.cutoff
+ )
+ }
+
yhat <- fit$classes[y.pred.c]
return(yhat)
diff --git a/R/predict.gensvm.grid.R b/R/predict.gensvm.grid.R
new file mode 100644
index 0000000..81a0207
--- /dev/null
+++ b/R/predict.gensvm.grid.R
@@ -0,0 +1,47 @@
+#' @title Predict class labels from the GenSVMGrid class
+#'
+#' @description Predict class labels using the best model from a grid search.
+#' After doing a grid search with the \code{\link{gensvm.grid}} function, this
+#' function can be used to make predictions of class labels. It uses the best
+#' GenSVM model found during the grid search to do the predictions. Note that
+#' this model is only available if \code{refit=TRUE} was specified in the
+#' \code{\link{gensvm.grid}} call (the default).
+#'
+#' @param grid A \code{gensvm.grid} object trained with \code{refit=TRUE}
+#' @param newx Matrix of new values for \code{x} for which predictions need to
+#' be computed.
+#' @param \dots further arguments are passed to \code{\link{predict.gensvm}}
+#'
+#' @return a vector of class labels, with the same type as the original class
+#' labels provided to gensvm.grid()
+#'
+#' @method predict gensvm.grid
+#' @export
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # run a grid search
+#' grid <- gensvm.grid(x, y)
+#'
+#' # predict training sample
+#' y.hat <- predict(grid, x)
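+#'
+#' # compute the training accuracy with gensvm.accuracy
+#' gensvm.accuracy(y, y.hat)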
+#'
+predict.gensvm.grid <- function(grid, newx, ...)
+{
+ if (is.null(grid$best.estimator)) {
+ cat("Error: Can't predict, the best.estimator element is NULL\n")
+ return
+ }
+
+ return(predict(grid$best.estimator, newx, ...))
+}
diff --git a/R/print.gensvm.R b/R/print.gensvm.R
index 06a3649..119b264 100644
--- a/R/print.gensvm.R
+++ b/R/print.gensvm.R
@@ -2,13 +2,14 @@
#'
#' @description Prints a short description of the fitted GenSVM model
#'
-#' @param object A \code{gensvm} object to print
+#' @param fit A \code{gensvm} object to print
#' @param \dots further arguments are ignored
#'
-#' @return returns the object passed as input
+#' @return returns the object passed as input. This can be useful for chaining
+#' operations on a fit object.
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -20,37 +21,52 @@
#' @export
#'
#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
#'
+#' # fit and print the model
+#' fit <- gensvm(x, y)
+#' print(fit)
#'
-print.gensvm <- function(object, ...)
+#' # (advanced) use the fact that print returns the fitted model
+#' fit <- gensvm(x, y)
+#' predict(print(fit), x)
+#'
+print.gensvm <- function(fit, ...)
{
- cat("\nCall:\n")
- dput(object$call)
- cat("\nData:\n")
- cat("\tn.objects:", object$n.objects, "\n")
- cat("\tn.features:", object$n.features, "\n")
- cat("\tn.classes:", object$n.classes, "\n")
- cat("\tclasses:", object$classes, "\n")
+ cat("Data:\n")
+ cat("\tn.objects:", fit$n.objects, "\n")
+ cat("\tn.features:", fit$n.features, "\n")
+ cat("\tn.classes:", fit$n.classes, "\n")
+ if (is.factor(fit$classes))
+ cat("\tclasses:", levels(fit$classes), "\n")
+ else
+ cat("\tclasses:", fit$classes, "\n")
cat("Parameters:\n")
- cat("\tp:", object$p, "\n")
- cat("\tlambda:", object$lambda, "\n")
- cat("\tkappa:", object$kappa, "\n")
- cat("\tepsilon:", object$epsilon, "\n")
- cat("\tweights:", object$weights, "\n")
- cat("\tmax.iter:", object$max.iter, "\n")
- cat("\trandom.seed:", object$random.seed, "\n")
- cat("\tkernel:", object$kernel, "\n")
- if (object$kernel %in% c("poly", "rbf", "sigmoid")) {
- cat("\tkernel.eigen.cutoff:", object$kernel.eigen.cutoff, "\n")
- cat("\tgamma:", object$gamma, "\n")
+ cat("\tp:", fit$p, "\n")
+ cat("\tlambda:", fit$lambda, "\n")
+ cat("\tkappa:", fit$kappa, "\n")
+ cat("\tepsilon:", fit$epsilon, "\n")
+ cat("\tweights:", fit$weights, "\n")
+ cat("\tmax.iter:", fit$max.iter, "\n")
+ cat("\trandom.seed:", fit$random.seed, "\n")
+ if (is.factor(fit$kernel)) {
+ cat("\tkernel:", levels(fit$kernel)[as.numeric(fit$kernel)], "\n")
+ } else {
+ cat("\tkernel:", fit$kernel, "\n")
+ }
+ if (fit$kernel %in% c("poly", "rbf", "sigmoid")) {
+ cat("\tkernel.eigen.cutoff:", fit$kernel.eigen.cutoff, "\n")
+ cat("\tgamma:", fit$gamma, "\n")
}
- if (object$kernel %in% c("poly", "sigmoid"))
- cat("\tcoef:", object$coef, "\n")
- if (object$kernel == 'poly')
- cat("\tdegree:", object$degree, "\n")
+ if (fit$kernel %in% c("poly", "sigmoid"))
+ cat("\tcoef:", fit$coef, "\n")
+ if (fit$kernel == 'poly')
+ cat("\tdegree:", fit$degree, "\n")
cat("Results:\n")
- cat("\tn.iter:", object$n.iter, "\n")
- cat("\tn.support:", object$n.support, "\n")
+ cat("\ttime:", fit$training.time, "\n")
+ cat("\tn.iter:", fit$n.iter, "\n")
+ cat("\tn.support:", fit$n.support, "\n")
- invisible(object)
+ invisible(fit)
}
diff --git a/R/print.gensvm.grid.R b/R/print.gensvm.grid.R
new file mode 100644
index 0000000..88967d7
--- /dev/null
+++ b/R/print.gensvm.grid.R
@@ -0,0 +1,61 @@
+#' @title Print the fitted GenSVMGrid model
+#'
+#' @description Prints the summary of the fitted GenSVMGrid model
+#'
+#' @param grid a \code{gensvm.grid} object to print
+#' @param \dots further arguments are ignored
+#'
+#' @return returns the object passed as input
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @method print gensvm.grid
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # fit a grid search and print the resulting object
+#' grid <- gensvm.grid(x, y)
+#' print(grid)
+#'
+print.gensvm.grid <- function(grid, ...)
+{
+ cat("Data:\n")
+ cat("\tn.objects:", grid$n.objects, "\n")
+ cat("\tn.features:", grid$n.features, "\n")
+ cat("\tn.classes:", grid$n.classes, "\n")
+ if (is.factor(grid$classes))
+ cat("\tclasses:", levels(grid$classes), "\n")
+ else
+ cat("\tclasses:", grid$classes, "\n")
+ cat("Config:\n")
+ cat("\tNumber of cv splits:", grid$n.splits, "\n")
+ not.run <- sum(is.na(grid$cv.results$rank.test.score))
+ if (not.run > 0) {
+ cat("\tParameter grid size:", dim(grid$param.grid)[1])
+ cat(" (", not.run, " incomplete)", sep="")
+ cat("\n")
+ } else {
+ cat("\tParameter grid size:", dim(grid$param.grid)[1], "\n")
+ }
+ cat("Results:\n")
+ cat("\tTotal grid search time:", grid$total.time, "\n")
+ if (!is.na(grid$best.index)) {
+ best <- grid$cv.results[grid$best.index, ]
+ cat("\tBest mean test score:", best$mean.test.score, "\n")
+ cat("\tBest mean fit time:", best$mean.fit.time, "\n")
+ for (name in colnames(grid$best.params))
+ cat("\tBest parameter", name, "=", grid$best.params[[name]], "\n")
+ }
+
+ invisible(grid)
+}