diff options
Diffstat (limited to 'R/predict.gensvm.R')
| -rw-r--r-- | R/predict.gensvm.R | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 6cc8851..5c8f2e7 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -3,7 +3,7 @@ #' @description This function predicts the class labels of new data using a #' fitted GenSVM model. #' -#' @param object Fitted \code{gensvm} object +#' @param fit Fitted \code{gensvm} object #' @param newx Matrix of new values for \code{x} for which predictions need to #' be made. #' @param \dots further arguments are ignored @@ -27,10 +27,27 @@ #' #' #' -predict.gensvm <- function(object, newx, ...) +predict.gensvm <- function(fit, newx, ...) { - # TODO: C library fitting prediction here (or not? with the column-major - # order it may be faster to do it directly in R) + ## Implementation note: + ## - It might seem that it would be faster to do the prediction directly in + ## R here, since we then don't need to switch to C, construct model and + ## data structures, copy the data, etc. before doing the prediction. + ## However, if you actually implement it and compare, we find that calling + ## the C implementation is *much* faster than doing it in R. + + # Sanity check + if (ncol(newx) != fit$n.features) + stop("Number of features of fitted model and supplied data disagree.") + + y.pred.c <- .Call("R_gensvm_predict", + as.matrix(newx), + as.matrix(fit$V), + as.integer(nrow(newx)), + as.integer(ncol(newx)), + as.integer(fit$n.classes) + ) + yhat <- fit$classes[y.pred.c] return(yhat) } |
