diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 12:31:28 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-27 12:31:28 +0100 |
| commit | 004941896bac692d354c41a3334d20ee1d4627f7 (patch) | |
| tree | 2b11e42d8524843409e2bf8deb4ceb74c8b69347 /R/predict.gensvm.R | |
| parent | updates to GenSVM C library (diff) | |
| download | rgensvm-004941896bac692d354c41a3334d20ee1d4627f7.tar.gz rgensvm-004941896bac692d354c41a3334d20ee1d4627f7.zip | |
GenSVM R package
Diffstat (limited to 'R/predict.gensvm.R')
| -rw-r--r-- | R/predict.gensvm.R | 67 |
1 files changed, 57 insertions, 10 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 5c8f2e7..7e04fe4 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -4,8 +4,8 @@ #' fitted GenSVM model. #' #' @param fit Fitted \code{gensvm} object -#' @param newx Matrix of new values for \code{x} for which predictions need to -#' be made. +#' @param x.test Matrix of new values for \code{x} for which predictions need +#' to be made. #' @param \dots further arguments are ignored #' #' @return a vector of class labels, with the same type as the original class @@ -15,7 +15,7 @@ #' @aliases predict #' #' @author -#' Gerrit J.J. van den Burg, Patrick J.F. Groenen +#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr #' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com> #' #' @references @@ -24,10 +24,20 @@ #' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}. #' #' @examples +#' x <- iris[, -5] +#' y <- iris[, 5] #' +#' # create a training and test sample +#' attach(gensvm.train.test.split(x, y)) +#' fit <- gensvm(x.train, y.train) #' +#' # predict the class labels of the test sample +#' y.test.pred <- predict(fit, x.test) #' -predict.gensvm <- function(fit, newx, ...) +#' # compute the accuracy with gensvm.accuracy +#' gensvm.accuracy(y.test, y.test.pred) +#' +predict.gensvm <- function(fit, x.test, ...) { ## Implementation note: ## - It might seem that it would be faster to do the prediction directly in @@ -37,16 +47,53 @@ predict.gensvm <- function(fit, newx, ...) ## the C implementation is *much* faster than doing it in R. # Sanity check - if (ncol(newx) != fit$n.features) - stop("Number of features of fitted model and supplied data disagree.") + if (ncol(x.test) != fit$n.features) { + cat("Error: Number of features of fitted model and testing", + "data disagree.\n") + return + } + + x.train <- fit$X.train + if (fit$kernel != 'linear' && is.null(x.train)) { + cat("Error: The training data is needed to compute predictions for ", + "nonlinear GenSVM. This data is not present in the fitted ", + "model!\n", sep="") + } + if (!is.null(x.train) && ncol(x.train) != fit$n.features) { + cat("Error: Number of features of fitted model and training", + "data disagree.\n") + return + } - y.pred.c <- .Call("R_gensvm_predict", - as.matrix(newx), + if (fit$kernel == 'linear') { + y.pred.c <- .Call("R_gensvm_predict", + as.matrix(x.test), as.matrix(fit$V), - as.integer(nrow(newx)), - as.integer(ncol(newx)), + as.integer(nrow(x.test)), + as.integer(ncol(x.test)), as.integer(fit$n.classes) ) + } else { + kernels <- c("linear", "poly", "rbf", "sigmoid") + kernel.idx <- which(kernels == fit$kernel) - 1 + y.pred.c <- .Call("R_gensvm_predict_kernels", + as.matrix(x.test), + as.matrix(x.train), + as.matrix(fit$V), + as.integer(nrow(fit$V)), + as.integer(ncol(fit$V)), + as.integer(nrow(x.train)), + as.integer(nrow(x.test)), + as.integer(fit$n.features), + as.integer(fit$n.classes), + as.integer(kernel.idx), + fit$gamma, + fit$coef, + fit$degree, + fit$kernel.eigen.cutoff + ) + } + yhat <- fit$classes[y.pred.c] return(yhat) |
