diff options
Diffstat (limited to 'R/predict.gensvm.R')
| -rw-r--r-- | R/predict.gensvm.R | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 558f80b..2cf7401 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -4,8 +4,9 @@ #' fitted GenSVM model. #' #' @param fit Fitted \code{gensvm} object -#' @param x.test Matrix of new values for \code{x} for which predictions need -#' to be made. +#' @param newdata Matrix of new data for which predictions need to be made. +#' @param add.rownames add the rownames from the training data to the +#' predictions #' @param \dots further arguments are ignored #' #' @return a vector of class labels, with the same type as the original class @@ -41,7 +42,7 @@ #' # compute the accuracy with gensvm.accuracy #' gensvm.accuracy(y.test, y.test.pred) #' -predict.gensvm <- function(fit, x.test, ...) +predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...) { ## Implementation note: ## - It might seem that it would be faster to do the prediction directly in @@ -50,26 +51,19 @@ predict.gensvm <- function(fit, x.test, ...) ## However, if you actually implement it and compare, we find that calling ## the C implementation is *much* faster than doing it in R. + if (missing(newdata)) { + newdata <- eval.parent(fit$call$x) + } + x.test <- as.matrix(newdata) + # Sanity check if (ncol(x.test) != fit$n.features) { cat("Error: Number of features of fitted model and testing", "data disagree.\n") - invisible(NULL) - } - - x.train <- fit$X.train - if (fit$kernel != 'linear' && is.null(x.train)) { - cat("Error: The training data is needed to compute predictions for ", - "nonlinear GenSVM. This data is not present in the fitted ", - "model!\n", sep="") - invisible(NULL) - } - if (!is.null(x.train) && ncol(x.train) != fit$n.features) { - cat("Error: Number of features of fitted model and training", - "data disagree.\n") - invisible(NULL) + return(invisible(NULL)) } + x.train <- eval.parent(fit$call$x) if (fit$kernel == 'linear') { y.pred.c <- .Call("R_gensvm_predict", as.matrix(x.test), @@ -101,5 +95,10 @@ predict.gensvm <- function(fit, x.test, ...) yhat <- fit$classes[y.pred.c] + if (add.rownames) { + yhat <- matrix(yhat, length(yhat), 1) + rownames(yhat) <- rownames(x.train) + } + return(yhat) } |
