diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 12:31:28 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 12:31:28 +0100 |
| commit | 004941896bac692d354c41a3334d20ee1d4627f7 (patch) | |
| tree | 2b11e42d8524843409e2bf8deb4ceb74c8b69347 /R/gensvm.train.test.split.R | |
| parent | updates to GenSVM C library (diff) | |
| download | rgensvm-004941896bac692d354c41a3334d20ee1d4627f7.tar.gz rgensvm-004941896bac692d354c41a3334d20ee1d4627f7.zip | |
GenSVM R package
Diffstat (limited to 'R/gensvm.train.test.split.R')
| -rw-r--r-- | R/gensvm.train.test.split.R | 121 |
1 files changed, 121 insertions, 0 deletions
diff --git a/R/gensvm.train.test.split.R b/R/gensvm.train.test.split.R new file mode 100644 index 0000000..406f80e --- /dev/null +++ b/R/gensvm.train.test.split.R @@ -0,0 +1,121 @@ +#' @title Create a train/test split of a dataset +#' +#' @description Often it is desirable to split a dataset into a training and +#' testing sample. This function is included in GenSVM to make it easy to do +#' so. The function is inspired by a similar function in Scikit-Learn. +#' +#' @param x array to split +#' @param y another array to split (typically this is a vector) +#' @param train.size size of the training dataset. This can be provided as +#' float or as int. If it's a float, it should be between 0.0 and 1.0 and +#' represents the fraction of the dataset that should be placed in the training +#' dataset. If it's an int, it represents the exact number of samples in the +#' training dataset. If it is NULL, the complement of \code{test.size} will be +#' used. +#' @param test.size size of the test dataset. Similarly to train.size both a +#' float or an int can be supplied. If it's NULL, the complement of train.size +#' will be used. If both train.size and test.size are NULL, a default test.size +#' of 0.25 will be used. +#' @param shuffle shuffle the rows or not +#' @param random.state seed for the random number generator (int) +#' +#' @author +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr +#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> +#' +#' @references +#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized +#' Multiclass Support Vector Machine}, Journal of Machine Learning Research, +#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. +#' +#' @export +#' +#' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] +#' +#' # using the default values +#' split <- gensvm.train.test.split(x, y) +#' +#' # using the split in a GenSVM model +#' fit <- gensvm(split$x.train, split$y.train) +#' gensvm.accuracy(split$y.test, predict(fit, split$x.test)) +#' +#' # using attach makes the results directly available +#' attach(gensvm.train.test.split(x, y)) +#' fit <- gensvm(x.train, y.train) +#' gensvm.accuracy(y.test, predict(fit, x.test)) +#' +gensvm.train.test.split <- function(x, y=NULL, train.size=NULL, test.size=NULL, + shuffle=TRUE, random.state=NULL, + return.idx=FALSE) +{ + if (!is.null(y) && dim(as.matrix(x))[1] != dim(as.matrix(y))[1]) { + cat("Error: First dimension of x and y should be equal.\n") + return + } + + n.objects <- dim(as.matrix(x))[1] + + if (is.null(train.size) && is.null(test.size)) { + test.size <- round(0.25 * n.objects) + train.size <- n.objects - test.size + } + else if (is.null(train.size)) { + if (test.size > 0.0 && test.size < 1.0) + test.size <- round(n.objects * test.size) + train.size <- n.objects - test.size + } + else if (is.null(test.size)) { + if (train.size > 0.0 && train.size < 1.0) + train.size <- round(n.objects * train.size) + test.size <- n.objects - train.size + } + else { + if (train.size > 0.0 && train.size < 1.0) + train.size <- round(n.objects * train.size) + if (test.size > 0.0 && test.size < 1.0) + test.size <- round(n.objects * test.size) + } + + if (!is.null(random.state)) + set.seed(random.state) + + if (shuffle) { + train.idx <- sample(n.objects, train.size) + diff <- setdiff(1:n.objects, train.idx) + test.idx <- sample(diff, test.size) + } else { + train.idx <- 1:train.size + diff <- setdiff(1:n.objects, train.idx) + test.idx <- diff[1:test.size] + } + + x.train <- x[train.idx, ] + x.test <- x[test.idx, ] + + if (!is.null(y)) { + if (is.matrix(y)) { + y.train <- y[train.idx, ] + y.test <- y[test.idx, ] + } else { + y.train <- y[train.idx] + y.test <- y[test.idx] + } + } + + out <- list( + x.train = x.train, + x.test = x.test + ) + if (!is.null(y)) { + out$y.train <- y.train + out$y.test <- y.test + } + if (return.idx) { + out$idx.train <- train.idx + out$idx.test <- test.idx + } + + return(out) +} |
