diff options
Diffstat (limited to 'R/plot.gensvm.R')
| -rw-r--r-- | R/plot.gensvm.R | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R index d490c1e..129d90a 100644 --- a/R/plot.gensvm.R +++ b/R/plot.gensvm.R @@ -5,10 +5,11 @@ #' 3 classes. For more than 3 classes, the simplex space is too high #' dimensional to easily visualize. #' -#' @param fit A fitted \code{gensvm} object -#' @param x the dataset to plot (if NULL the training data is used) -#' @param y the labels to color points with (if NULL the predicted labels are -#' used) +#' @param x A fitted \code{gensvm} object +#' @param labels the labels to color points with. If this is omitted the +#' predicted labels are used. +#' @param newdata the dataset to plot. If this is NULL the training data is +#' used. #' @param with.margins plot the margins #' @param with.shading show shaded areas for the class regions #' @param with.legend show the legend for the class labels @@ -62,19 +63,20 @@ #' plot(fit, x.test=x.mis) #' plot(fit, y.mis.true, x.test=x.mis) #' -plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, +plot.gensvm <- function(x, labels, newdata=NULL, with.margins=TRUE, with.shading=TRUE, with.legend=TRUE, center.plot=TRUE, xlim=NULL, ylim=NULL, ...) { + fit <- x if (!(fit$n.classes %in% c(2,3))) { cat("Error: Can only plot with 2 or 3 classes\n") return(invisible(NULL)) } - x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test + newdata <- if(is.null(newdata)) eval.parent(fit$call$x) else newdata # Sanity check - if (ncol(x.test) != fit$n.features) { + if (ncol(newdata) != fit$n.features) { cat("Error: Number of features of fitted model and testing data disagree.\n") return(invisible(NULL)) @@ -83,20 +85,20 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, x.train <- eval.parent(fit$call$x) if (fit$kernel == 'linear') { V <- coef(fit) - Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test)) + Z <- cbind(matrix(1, dim(newdata)[1], 1), as.matrix(newdata)) S <- Z %*% V - y.pred.orig <- predict(fit, x.test) + y.pred.orig <- predict(fit, newdata) } else { kernels <- c("linear", "poly", "rbf", "sigmoid") kernel.idx <- which(kernels == fit$kernel) - 1 plotdata <- .Call("R_gensvm_plotdata_kernels", - as.matrix(x.test), + as.matrix(newdata), as.matrix(x.train), as.matrix(fit$V), as.integer(nrow(fit$V)), as.integer(ncol(fit$V)), as.integer(nrow(x.train)), - as.integer(nrow(x.test)), + as.integer(nrow(newdata)), as.integer(fit$n.features), as.integer(fit$n.classes), as.integer(kernel.idx), @@ -116,7 +118,7 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, y.pred <- y.pred.orig } - labels <- if(missing(y)) y.pred else match(y, classes) + labels <- if(missing(labels)) y.pred else match(labels, classes) colors <- gensvm.plot.colors(fit$n.classes) markers <- gensvm.plot.markers(fit$n.classes) |
