diff options
Diffstat (limited to 'R/predict.gensvm.R')
| -rw-r--r-- | R/predict.gensvm.R | 38 |
1 files changed, 19 insertions, 19 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 0b5fa91..43a0d52 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -3,7 +3,7 @@ #' @description This function predicts the class labels of new data using a #' fitted GenSVM model. #' -#' @param fit Fitted \code{gensvm} object +#' @param object Fitted \code{gensvm} object #' @param newdata Matrix of new data for which predictions need to be made. #' @param add.rownames add the rownames from the training data to the #' predictions @@ -45,7 +45,7 @@ #' # compute the accuracy with gensvm.accuracy #' gensvm.accuracy(y.test, y.test.pred) #' -predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...) +predict.gensvm <- function(object, newdata, add.rownames=FALSE, ...) { ## Implementation note: ## - It might seem that it would be faster to do the prediction directly in @@ -55,48 +55,48 @@ predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...) ## the C implementation is *much* faster than doing it in R. if (missing(newdata)) { - newdata <- eval.parent(fit$call$x) + newdata <- eval.parent(object$call$x) } x.test <- as.matrix(newdata) # Sanity check - if (ncol(x.test) != fit$n.features) { + if (ncol(x.test) != object$n.features) { cat("Error: Number of features of fitted model and testing", "data disagree.\n") return(invisible(NULL)) } - x.train <- eval.parent(fit$call$x) - if (fit$kernel == 'linear') { + x.train <- eval.parent(object$call$x) + if (object$kernel == 'linear') { y.pred.c <- .Call("R_gensvm_predict", as.matrix(x.test), - as.matrix(fit$V), + as.matrix(object$V), as.integer(nrow(x.test)), as.integer(ncol(x.test)), - as.integer(fit$n.classes) + as.integer(object$n.classes) ) } else { kernels <- c("linear", "poly", "rbf", "sigmoid") - kernel.idx <- which(kernels == fit$kernel) - 1 + kernel.idx <- which(kernels == object$kernel) - 1 y.pred.c <- .Call("R_gensvm_predict_kernels", as.matrix(x.test), as.matrix(x.train), - as.matrix(fit$V), - as.integer(nrow(fit$V)), - as.integer(ncol(fit$V)), + as.matrix(object$V), + as.integer(nrow(object$V)), + as.integer(ncol(object$V)), as.integer(nrow(x.train)), as.integer(nrow(x.test)), - as.integer(fit$n.features), - as.integer(fit$n.classes), + as.integer(object$n.features), + as.integer(object$n.classes), as.integer(kernel.idx), - fit$gamma, - fit$coef, - fit$degree, - fit$kernel.eigen.cutoff + object$gamma, + object$coef, + object$degree, + object$kernel.eigen.cutoff ) } - yhat <- fit$classes[y.pred.c] + yhat <- object$classes[y.pred.c] if (add.rownames) { yhat <- matrix(yhat, length(yhat), 1) |
