aboutsummaryrefslogtreecommitdiff
path: root/R/plot.gensvm.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/plot.gensvm.R')
-rw-r--r--R/plot.gensvm.R199
1 files changed, 199 insertions, 0 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
new file mode 100644
index 0000000..0ce215b
--- /dev/null
+++ b/R/plot.gensvm.R
@@ -0,0 +1,199 @@
+#' @title Plot the simplex space of the fitted GenSVM model
+#'
+#' @description This function creates a plot of the simplex space for a fitted
+#' GenSVM model and the given data set, as long as the dataset consists of only
+#' 3 classes. For more than 3 classes, the simplex space is too high
+#' dimensional to easily visualize.
+#'
+#' @param fit A fitted \code{gensvm} object
+#' @param x the dataset to plot
+#' @param y.true the true data labels. If provided the objects will be colored
+#' using the true labels instead of the predicted labels. This makes it easy to
+#' identify misclassified objects.
+#' @param with.margins plot the margins
+#' @param with.shading show shaded areas for the class regions
+#' @param with.legend show the legend for the class labels
+#' @param center.plot ensure that the boundaries and margins are always visible
+#' in the plot
+#' @param ... further arguments are ignored
+#'
+#' @return returns the object passed as input
+#'
+#' @author
+#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
+#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
+#'
+#' @references
+#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
+#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
+#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
+#'
+#' @method plot gensvm
+#' @export
+#'
+#' @examples
+#' x <- iris[, -5]
+#' y <- iris[, 5]
+#'
+#' # train the model
+#' fit <- gensvm(x, y)
+#'
+#' # plot the simplex space
+#' plot(fit, x)
+#'
+#' # plot and use the true colors (easier to spot misclassified samples)
+#' plot(fit, x, y.true=y)
+#'
+#' # plot only misclassified samples
+#' x.mis <- x[predict(fit, x) != y, ]
+#' y.mis.true <- y[predict(fit, x) != y, ]
+#' plot(fit, x.bad)
+#' plot(fit, x.bad, y.true=y.mis.true)
+#'
+plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
+ with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
+ ...)
+{
+ if (fit$n.classes != 3) {
+ cat("Error: Can only plot with 3 classes\n")
+ return
+ }
+
+ # Sanity check
+ if (ncol(x) != fit$n.features) {
+ cat("Error: Number of features of fitted model and testing data
+ disagree.\n")
+ return
+ }
+
+ x.train <- fit$X.train
+ if (fit$kernel != 'linear' && is.null(x.train)) {
+ cat("Error: The training data is needed to plot data for ",
+ "nonlinear GenSVM. This data is not present in the fitted ",
+ "model!\n", sep="")
+ return
+ }
+ if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
+ cat("Error: Number of features of fitted model and training data disagree.")
+ return
+ }
+
+ x <- as.matrix(x)
+
+ if (fit$kernel == 'linear') {
+ V <- coef(fit)
+ Z <- cbind(matrix(1, dim(x)[1], 1), x)
+ S <- Z %*% V
+ y.pred.orig <- predict(fit, x)
+ } else {
+ kernels <- c("linear", "poly", "rbf", "sigmoid")
+ kernel.idx <- which(kernels == fit$kernel) - 1
+ plotdata <- .Call("R_gensvm_plotdata_kernels",
+ as.matrix(x),
+ as.matrix(x.train),
+ as.matrix(fit$V),
+ as.integer(nrow(fit$V)),
+ as.integer(ncol(fit$V)),
+ as.integer(nrow(x.train)),
+ as.integer(nrow(x)),
+ as.integer(fit$n.features),
+ as.integer(fit$n.classes),
+ as.integer(kernel.idx),
+ fit$gamma,
+ fit$coef,
+ fit$degree,
+ fit$kernel.eigen.cutoff
+ )
+ S <- plotdata$ZV
+ y.pred.orig <- plotdata$y.pred
+ }
+
+ classes <- fit$classes
+ if (is.factor(y.pred.orig)) {
+ y.pred <- match(y.pred.orig, classes)
+ } else {
+ y.pred <- y.pred.orig
+ }
+
+ # Define some colors
+ point.blue <- rgb(31, 119, 180, maxColorValue=255)
+ point.orange <- rgb(255, 127, 14, maxColorValue=255)
+ point.green <- rgb(44, 160, 44, maxColorValue=255)
+ fill.blue <- rgb(31, 119, 180, 51, maxColorValue=255)
+ fill.orange <- rgb(255, 127, 14, 51, maxColorValue=255)
+ fill.green <- rgb(44, 160, 44, 51, maxColorValue=255)
+
+ colors <- as.matrix(c(point.green, point.blue, point.orange))
+ markers <- as.matrix(c(15, 16, 17))
+
+ if (is.null(y.true)) {
+ col.vector <- colors[y.pred]
+ mark.vector <- markers[y.pred]
+ } else {
+ col.vector <- colors[y.true]
+ mark.vector <- markers[y.true]
+ }
+
+ par(pty="s")
+ if (center.plot) {
+ new.xlim <- c(min(min(S[, 1]), -1.2), max(max(S[, 1]), 1.2))
+ new.ylim <- c(min(min(S[, 2]), -0.75), max(max(S[, 2]), 1.2))
+ plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='',
+ asp=1, xlim=new.xlim, ylim=new.ylim)
+ } else {
+ plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='',
+ asp=1)
+ }
+
+ limits <- par("usr")
+ xmin <- limits[1]
+ xmax <- limits[2]
+ ymin <- limits[3]
+ ymax <- limits[4]
+
+ # draw the fixed boundaries
+ segments(0, 0, 0, ymin)
+ segments(0, 0, xmax, xmax/sqrt(3))
+ segments(xmin, abs(xmin)/sqrt(3), 0, 0)
+
+ if (with.margins) {
+ # margin from left below decision boundary to center
+ segments(xmin, -xmin/sqrt(3) - sqrt(4/3), -1, -1/sqrt(3), lty=2)
+
+ # margin from left center to down
+ segments(-1, -1/sqrt(3), -1, ymin, lty=2)
+
+ # margin from right center to middle
+ segments(1, -1/sqrt(3), 1, ymin, lty=2)
+
+ # margin from right center to right boundary
+ segments(1, -1/sqrt(3), xmax, xmax/sqrt(3) - sqrt(4/3), lty=2)
+
+ # margin from center to top left
+ segments(xmin, -xmin/sqrt(3) + sqrt(4/3), 0, sqrt(4/3), lty=2)
+
+ # margin from center to top right
+ segments(0, sqrt(4/3), xmax, xmax/sqrt(3) + sqrt(4/3), lty=2)
+ }
+
+ if (with.shading) {
+ # bottom left
+ polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, -1/sqrt(3), -xmin/sqrt(3) -
+ sqrt(4/3)), col=fill.green, border=NA)
+ # bottom right
+ polygon(c(1, xmax, xmax, 1), c(ymin, ymin, xmax/sqrt(3) - sqrt(4/3),
+ -1/sqrt(3)), col=fill.blue, border=NA)
+ # top
+ polygon(c(xmin, 0, xmax, xmax, xmin),
+ c(-xmin/sqrt(3) + sqrt(4/3), sqrt(4/3), xmax/sqrt(3) + sqrt(4/3),
+ ymax, ymax), col=fill.orange,
+ border=NA)
+ }
+
+ if (with.legend) {
+ offset <- abs(xmax - xmin) * 0.05
+ legend(xmax + offset, ymax, classes, col=colors, pch=markers, xpd=T)
+ }
+
+ invisible(fit)
+}