aboutsummaryrefslogtreecommitdiff
path: root/R/predict.gensvm.R
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-02-23 17:10:16 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-02-23 17:10:16 +0000
commitbdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468 (patch)
treedcf17ae3e5fef6cee367f0da985bb4894495aee9 /R/predict.gensvm.R
parentupdate gensvm C library (diff)
downloadrgensvm-bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468.tar.gz
rgensvm-bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468.zip
Implement fitting and prediction
Diffstat (limited to 'R/predict.gensvm.R')
-rw-r--r--R/predict.gensvm.R25
1 files changed, 21 insertions, 4 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 6cc8851..5c8f2e7 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -3,7 +3,7 @@
#' @description This function predicts the class labels of new data using a
#' fitted GenSVM model.
#'
-#' @param object Fitted \code{gensvm} object
+#' @param fit Fitted \code{gensvm} object
#' @param newx Matrix of new values for \code{x} for which predictions need to
#' be made.
#' @param \dots further arguments are ignored
@@ -27,10 +27,27 @@
#'
#'
#'
-predict.gensvm <- function(object, newx, ...)
+predict.gensvm <- function(fit, newx, ...)
{
- # TODO: C library fitting prediction here (or not? with the column-major
- # order it may be faster to do it directly in R)
+ ## Implementation note:
+ ## - It might seem that it would be faster to do the prediction directly in
+ ## R here, since we then don't need to switch to C, construct model and
+ ## data structures, copy the data, etc. before doing the prediction.
+ ## However, if you actually implement it and compare, we find that calling
+ ## the C implementation is *much* faster than doing it in R.
+
+ # Sanity check
+ if (ncol(newx) != fit$n.features)
+ stop("Number of features of fitted model and supplied data disagree.")
+
+ y.pred.c <- .Call("R_gensvm_predict",
+ as.matrix(newx),
+ as.matrix(fit$V),
+ as.integer(nrow(newx)),
+ as.integer(ncol(newx)),
+ as.integer(fit$n.classes)
+ )
+ yhat <- fit$classes[y.pred.c]
return(yhat)
}