aboutsummaryrefslogtreecommitdiff
path: root/R/predict.gensvm.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/predict.gensvm.R')
-rw-r--r--R/predict.gensvm.R67
1 files changed, 57 insertions, 10 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 5c8f2e7..7e04fe4 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -4,8 +4,8 @@
#' fitted GenSVM model.
#'
#' @param fit Fitted \code{gensvm} object
-#' @param newx Matrix of new values for \code{x} for which predictions need to
-#' be made.
+#' @param x.test Matrix of new values for \code{x} for which predictions need
+#' to be made.
#' @param \dots further arguments are ignored
#'
#' @return a vector of class labels, with the same type as the original class
@@ -15,7 +15,7 @@
#' @aliases predict
#'
#' @author
-#' Gerrit J.J. van den Burg, Patrick J.F. Groenen
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
@@ -24,10 +24,20 @@
#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
#'
#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
#'
+#' # create a training and test sample
+#' attach(gensvm.train.test.split(x, y))
+#' fit <- gensvm(x.train, y.train)
#'
+#' # predict the class labels of the test sample
+#' y.test.pred <- predict(fit, x.test)
#'
-predict.gensvm <- function(fit, newx, ...)
+#' # compute the accuracy with gensvm.accuracy
+#' gensvm.accuracy(y.test, y.test.pred)
+#'
+predict.gensvm <- function(fit, x.test, ...)
{
## Implementation note:
## - It might seem that it would be faster to do the prediction directly in
@@ -37,16 +47,53 @@ predict.gensvm <- function(fit, newx, ...)
## the C implementation is *much* faster than doing it in R.
# Sanity check
- if (ncol(newx) != fit$n.features)
- stop("Number of features of fitted model and supplied data disagree.")
+ if (ncol(x.test) != fit$n.features) {
+ cat("Error: Number of features of fitted model and testing",
+ "data disagree.\n")
+ return
+ }
+
+ x.train <- fit$X.train
+ if (fit$kernel != 'linear' && is.null(x.train)) {
+ cat("Error: The training data is needed to compute predictions for ",
+ "nonlinear GenSVM. This data is not present in the fitted ",
+ "model!\n", sep="")
+ }
+ if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
+ cat("Error: Number of features of fitted model and training",
+ "data disagree.\n")
+ return
+ }
- y.pred.c <- .Call("R_gensvm_predict",
- as.matrix(newx),
+ if (fit$kernel == 'linear') {
+ y.pred.c <- .Call("R_gensvm_predict",
+ as.matrix(x.test),
as.matrix(fit$V),
- as.integer(nrow(newx)),
- as.integer(ncol(newx)),
+ as.integer(nrow(x.test)),
+ as.integer(ncol(x.test)),
as.integer(fit$n.classes)
)
+ } else {
+ kernels <- c("linear", "poly", "rbf", "sigmoid")
+ kernel.idx <- which(kernels == fit$kernel) - 1
+ y.pred.c <- .Call("R_gensvm_predict_kernels",
+ as.matrix(x.test),
+ as.matrix(x.train),
+ as.matrix(fit$V),
+ as.integer(nrow(fit$V)),
+ as.integer(ncol(fit$V)),
+ as.integer(nrow(x.train)),
+ as.integer(nrow(x.test)),
+ as.integer(fit$n.features),
+ as.integer(fit$n.classes),
+ as.integer(kernel.idx),
+ fit$gamma,
+ fit$coef,
+ fit$degree,
+ fit$kernel.eigen.cutoff
+ )
+ }
+
yhat <- fit$classes[y.pred.c]
return(yhat)