aboutsummaryrefslogtreecommitdiff
path: root/R/plot.gensvm.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/plot.gensvm.R')
-rw-r--r--R/plot.gensvm.R26
1 files changed, 14 insertions, 12 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
index d490c1e..129d90a 100644
--- a/R/plot.gensvm.R
+++ b/R/plot.gensvm.R
@@ -5,10 +5,11 @@
#' 3 classes. For more than 3 classes, the simplex space is too high
#' dimensional to easily visualize.
#'
-#' @param fit A fitted \code{gensvm} object
-#' @param x the dataset to plot (if NULL the training data is used)
-#' @param y the labels to color points with (if NULL the predicted labels are
-#' used)
+#' @param x A fitted \code{gensvm} object
+#' @param labels the labels to color points with. If this is omitted the
+#' predicted labels are used.
+#' @param newdata the dataset to plot. If this is NULL the training data is
+#' used.
#' @param with.margins plot the margins
#' @param with.shading show shaded areas for the class regions
#' @param with.legend show the legend for the class labels
@@ -62,19 +63,20 @@
#' plot(fit, x.test=x.mis)
#' plot(fit, y.mis.true, x.test=x.mis)
#'
-plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
+plot.gensvm <- function(x, labels, newdata=NULL, with.margins=TRUE,
with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
xlim=NULL, ylim=NULL, ...)
{
+ fit <- x
if (!(fit$n.classes %in% c(2,3))) {
cat("Error: Can only plot with 2 or 3 classes\n")
return(invisible(NULL))
}
- x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test
+ newdata <- if(is.null(newdata)) eval.parent(fit$call$x) else newdata
# Sanity check
- if (ncol(x.test) != fit$n.features) {
+ if (ncol(newdata) != fit$n.features) {
cat("Error: Number of features of fitted model and testing data
disagree.\n")
return(invisible(NULL))
@@ -83,20 +85,20 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
x.train <- eval.parent(fit$call$x)
if (fit$kernel == 'linear') {
V <- coef(fit)
- Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test))
+ Z <- cbind(matrix(1, dim(newdata)[1], 1), as.matrix(newdata))
S <- Z %*% V
- y.pred.orig <- predict(fit, x.test)
+ y.pred.orig <- predict(fit, newdata)
} else {
kernels <- c("linear", "poly", "rbf", "sigmoid")
kernel.idx <- which(kernels == fit$kernel) - 1
plotdata <- .Call("R_gensvm_plotdata_kernels",
- as.matrix(x.test),
+ as.matrix(newdata),
as.matrix(x.train),
as.matrix(fit$V),
as.integer(nrow(fit$V)),
as.integer(ncol(fit$V)),
as.integer(nrow(x.train)),
- as.integer(nrow(x.test)),
+ as.integer(nrow(newdata)),
as.integer(fit$n.features),
as.integer(fit$n.classes),
as.integer(kernel.idx),
@@ -116,7 +118,7 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
y.pred <- y.pred.orig
}
- labels <- if(missing(y)) y.pred else match(y, classes)
+ labels <- if(missing(labels)) y.pred else match(labels, classes)
colors <- gensvm.plot.colors(fit$n.classes)
markers <- gensvm.plot.markers(fit$n.classes)