aboutsummaryrefslogtreecommitdiff
path: root/R/predict.gensvm.R
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-04-04 15:06:33 -0400
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-04-04 15:06:33 -0400
commitde17a6d755e9369a91abdb06562ee93d7b323bbd (patch)
treecbede010da1046ee6627143aa0e41d8ec0fbb09e /R/predict.gensvm.R
parentAdd importFrom statements (diff)
downloadrgensvm-de17a6d755e9369a91abdb06562ee93d7b323bbd.tar.gz
rgensvm-de17a6d755e9369a91abdb06562ee93d7b323bbd.zip
Adhere to generic function signatures
Diffstat (limited to 'R/predict.gensvm.R')
-rw-r--r--R/predict.gensvm.R38
1 files changed, 19 insertions, 19 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 0b5fa91..43a0d52 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -3,7 +3,7 @@
#' @description This function predicts the class labels of new data using a
#' fitted GenSVM model.
#'
-#' @param fit Fitted \code{gensvm} object
+#' @param object Fitted \code{gensvm} object
#' @param newdata Matrix of new data for which predictions need to be made.
#' @param add.rownames add the rownames from the training data to the
#' predictions
@@ -45,7 +45,7 @@
#' # compute the accuracy with gensvm.accuracy
#' gensvm.accuracy(y.test, y.test.pred)
#'
-predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...)
+predict.gensvm <- function(object, newdata, add.rownames=FALSE, ...)
{
## Implementation note:
## - It might seem that it would be faster to do the prediction directly in
@@ -55,48 +55,48 @@ predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...)
## the C implementation is *much* faster than doing it in R.
if (missing(newdata)) {
- newdata <- eval.parent(fit$call$x)
+ newdata <- eval.parent(object$call$x)
}
x.test <- as.matrix(newdata)
# Sanity check
- if (ncol(x.test) != fit$n.features) {
+ if (ncol(x.test) != object$n.features) {
cat("Error: Number of features of fitted model and testing",
"data disagree.\n")
return(invisible(NULL))
}
- x.train <- eval.parent(fit$call$x)
- if (fit$kernel == 'linear') {
+ x.train <- eval.parent(object$call$x)
+ if (object$kernel == 'linear') {
y.pred.c <- .Call("R_gensvm_predict",
as.matrix(x.test),
- as.matrix(fit$V),
+ as.matrix(object$V),
as.integer(nrow(x.test)),
as.integer(ncol(x.test)),
- as.integer(fit$n.classes)
+ as.integer(object$n.classes)
)
} else {
kernels <- c("linear", "poly", "rbf", "sigmoid")
- kernel.idx <- which(kernels == fit$kernel) - 1
+ kernel.idx <- which(kernels == object$kernel) - 1
y.pred.c <- .Call("R_gensvm_predict_kernels",
as.matrix(x.test),
as.matrix(x.train),
- as.matrix(fit$V),
- as.integer(nrow(fit$V)),
- as.integer(ncol(fit$V)),
+ as.matrix(object$V),
+ as.integer(nrow(object$V)),
+ as.integer(ncol(object$V)),
as.integer(nrow(x.train)),
as.integer(nrow(x.test)),
- as.integer(fit$n.features),
- as.integer(fit$n.classes),
+ as.integer(object$n.features),
+ as.integer(object$n.classes),
as.integer(kernel.idx),
- fit$gamma,
- fit$coef,
- fit$degree,
- fit$kernel.eigen.cutoff
+ object$gamma,
+ object$coef,
+ object$degree,
+ object$kernel.eigen.cutoff
)
}
- yhat <- fit$classes[y.pred.c]
+ yhat <- object$classes[y.pred.c]
if (add.rownames) {
yhat <- matrix(yhat, length(yhat), 1)