aboutsummaryrefslogtreecommitdiff
path: root/R/predict.gensvm.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/predict.gensvm.R')
-rw-r--r--R/predict.gensvm.R33
1 files changed, 16 insertions, 17 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 558f80b..2cf7401 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -4,8 +4,9 @@
#' fitted GenSVM model.
#'
#' @param fit Fitted \code{gensvm} object
-#' @param x.test Matrix of new values for \code{x} for which predictions need
-#' to be made.
+#' @param newdata Matrix of new data for which predictions need to be made.
+#' @param add.rownames add the rownames from the training data to the
+#' predictions
#' @param \dots further arguments are ignored
#'
#' @return a vector of class labels, with the same type as the original class
@@ -41,7 +42,7 @@
#' # compute the accuracy with gensvm.accuracy
#' gensvm.accuracy(y.test, y.test.pred)
#'
-predict.gensvm <- function(fit, x.test, ...)
+predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...)
{
## Implementation note:
## - It might seem that it would be faster to do the prediction directly in
@@ -50,26 +51,19 @@ predict.gensvm <- function(fit, x.test, ...)
## However, if you actually implement it and compare, we find that calling
## the C implementation is *much* faster than doing it in R.
+ if (missing(newdata)) {
+ newdata <- eval.parent(fit$call$x)
+ }
+ x.test <- as.matrix(newdata)
+
# Sanity check
if (ncol(x.test) != fit$n.features) {
cat("Error: Number of features of fitted model and testing",
"data disagree.\n")
- invisible(NULL)
- }
-
- x.train <- fit$X.train
- if (fit$kernel != 'linear' && is.null(x.train)) {
- cat("Error: The training data is needed to compute predictions for ",
- "nonlinear GenSVM. This data is not present in the fitted ",
- "model!\n", sep="")
- invisible(NULL)
- }
- if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
- cat("Error: Number of features of fitted model and training",
- "data disagree.\n")
- invisible(NULL)
+ return(invisible(NULL))
}
+ x.train <- eval.parent(fit$call$x)
if (fit$kernel == 'linear') {
y.pred.c <- .Call("R_gensvm_predict",
as.matrix(x.test),
@@ -101,5 +95,10 @@ predict.gensvm <- function(fit, x.test, ...)
yhat <- fit$classes[y.pred.c]
+ if (add.rownames) {
+ yhat <- matrix(yhat, length(yhat), 1)
+ rownames(yhat) <- rownames(x.train)
+ }
+
return(yhat)
}