aboutsummaryrefslogtreecommitdiff
path: root/R/plot.gensvm.R
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-30 22:07:11 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-30 22:07:11 +0100
commit9cbb676219df92e5600a1a19fe8d63ced10e1b28 (patch)
treedc47284c86d7d93f71b1c888c724da98295525b0 /R/plot.gensvm.R
parentReturn invisibly on error (diff)
downloadrgensvm-9cbb676219df92e5600a1a19fe8d63ced10e1b28.tar.gz
rgensvm-9cbb676219df92e5600a1a19fe8d63ced10e1b28.zip
Fixes to get the input data from the call
Diffstat (limited to 'R/plot.gensvm.R')
-rw-r--r--R/plot.gensvm.R66
1 files changed, 25 insertions, 41 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
index ac40ccf..d0c1eb2 100644
--- a/R/plot.gensvm.R
+++ b/R/plot.gensvm.R
@@ -6,10 +6,9 @@
#' dimensional to easily visualize.
#'
#' @param fit A fitted \code{gensvm} object
-#' @param x the dataset to plot
-#' @param y.true the true data labels. If provided the objects will be colored
-#' using the true labels instead of the predicted labels. This makes it easy to
-#' identify misclassified objects.
+#' @param x the dataset to plot (if NULL the training data is used)
+#' @param y the labels to color points with (if NULL the predicted labels are
+#' used)
#' @param with.margins plot the margins
#' @param with.shading show shaded areas for the class regions
#' @param with.legend show the legend for the class labels
@@ -48,19 +47,19 @@
#' fit <- gensvm(x, y)
#'
#' # plot the simplex space
-#' plot(fit, x)
+#' plot(fit)
#'
#' # plot and use the true colors (easier to spot misclassified samples)
-#' plot(fit, x, y.true=y)
+#' plot(fit, y)
#'
#' # plot only misclassified samples
-#' x.mis <- x[predict(fit, x) != y, ]
-#' y.mis.true <- y[predict(fit, x) != y]
-#' plot(fit, x.mis)
-#' plot(fit, x.mis, y.true=y.mis.true)
+#' x.mis <- x[predict(fit) != y, ]
+#' y.mis.true <- y[predict(fit) != y]
+#' plot(fit, x.test=x.mis)
+#' plot(fit, y.mis.true, x.test=x.mis)
#'
-plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
- with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
+plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
+ with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
xlim=NULL, ylim=NULL, ...)
{
if (!(fit$n.classes %in% c(2,3))) {
@@ -68,44 +67,32 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
return(invisible(NULL))
}
+ x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test
+
# Sanity check
- if (ncol(x) != fit$n.features) {
+ if (ncol(x.test) != fit$n.features) {
cat("Error: Number of features of fitted model and testing data
disagree.\n")
- invisible(NULL)
- }
-
- x.train <- fit$X.train
- if (fit$kernel != 'linear' && is.null(x.train)) {
- cat("Error: The training data is needed to plot data for ",
- "nonlinear GenSVM. This data is not present in the fitted ",
- "model!\n", sep="")
- invisible(NULL)
- }
- if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
- cat("Error: Number of features of fitted model and training data ",
- "disagree.\n", sep="")
- invisible(NULL)
+ return(invisible(NULL))
}
- x <- as.matrix(x)
-
+ x.train <- eval.parent(fit$call$x)
if (fit$kernel == 'linear') {
V <- coef(fit)
- Z <- cbind(matrix(1, dim(x)[1], 1), x)
+ Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test))
S <- Z %*% V
- y.pred.orig <- predict(fit, x)
+ y.pred.orig <- predict(fit, x.test)
} else {
kernels <- c("linear", "poly", "rbf", "sigmoid")
kernel.idx <- which(kernels == fit$kernel) - 1
plotdata <- .Call("R_gensvm_plotdata_kernels",
- as.matrix(x),
+ as.matrix(x.test),
as.matrix(x.train),
as.matrix(fit$V),
as.integer(nrow(fit$V)),
as.integer(ncol(fit$V)),
as.integer(nrow(x.train)),
- as.integer(nrow(x)),
+ as.integer(nrow(x.test)),
as.integer(fit$n.features),
as.integer(fit$n.classes),
as.integer(kernel.idx),
@@ -125,16 +112,13 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
y.pred <- y.pred.orig
}
- labels <- if(is.null(y.true)) y.pred else y.true
- classes <- unique(labels)
-
- colors <- gensvm.plot.colors(length(classes))
- markers <- gensvm.plot.markers(length(classes))
+ labels <- if(missing(y)) y.pred else match(y, classes)
- indices <- match(labels, classes)
+ colors <- gensvm.plot.colors(fit$n.classes)
+ markers <- gensvm.plot.markers(fit$n.classes)
- col.vector <- colors[indices]
- mark.vector <- markers[indices]
+ col.vector <- colors[labels]
+ mark.vector <- markers[labels]
if (fit$n.classes == 2)
S <- cbind(S, matrix(0, nrow(S), 1))