aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-30 17:10:23 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-30 17:10:23 +0100
commit2191af8180c4c2f655ecc4668765ed00332564ab (patch)
tree2867c407a060a81da68183285364c0a8a38678ed
parentReduce verbosity for refit (diff)
downloadrgensvm-2191af8180c4c2f655ecc4668765ed00332564ab.tar.gz
rgensvm-2191af8180c4c2f655ecc4668765ed00332564ab.zip
Add plot function for binary data
-rw-r--r--R/plot.gensvm.R110
1 files changed, 90 insertions, 20 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
index ad69597..a41c91e 100644
--- a/R/plot.gensvm.R
+++ b/R/plot.gensvm.R
@@ -125,24 +125,19 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
y.pred <- y.pred.orig
}
- # Define some colors
- point.blue <- rgb(31, 119, 180, maxColorValue=255)
- point.orange <- rgb(255, 127, 14, maxColorValue=255)
- point.green <- rgb(44, 160, 44, maxColorValue=255)
- fill.blue <- rgb(31, 119, 180, 51, maxColorValue=255)
- fill.orange <- rgb(255, 127, 14, 51, maxColorValue=255)
- fill.green <- rgb(44, 160, 44, 51, maxColorValue=255)
+ labels <- if(is.null(y.true)) y.pred else y.true
+ classes <- unique(labels)
- colors <- as.matrix(c(point.green, point.blue, point.orange))
- markers <- as.matrix(c(15, 16, 17))
+ colors <- gensvm.plot.colors(length(classes))
+ markers <- gensvm.plot.markers(length(classes))
- if (is.null(y.true)) {
- col.vector <- colors[y.pred]
- mark.vector <- markers[y.pred]
- } else {
- col.vector <- colors[y.true]
- mark.vector <- markers[y.true]
- }
+ indices <- match(labels, classes)
+
+ col.vector <- colors[indices]
+ mark.vector <- markers[indices]
+
+ if (fit$n.classes == 2)
+ S <- cbind(S, matrix(0, nrow(S), 1))
par(pty="s")
if (center.plot) {
@@ -157,6 +152,20 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
xlab='', asp=1, xlim=xlim, ylim=ylim, ...)
}
+ if (fit$n.classes == 3)
+ gensvm.plot.2d(classes, with.margins, with.shading, with.legend,
+ center.plot)
+ else
+ gensvm.plot.1d(classes, with.margins, with.shading, with.legend,
+ center.plot)
+
+ invisible(fit)
+}
+
+
+gensvm.plot.2d <- function(classes, with.margins, with.shading,
+ with.legend, center.plot)
+{
limits <- par("usr")
xmin <- limits[1]
xmax <- limits[2]
@@ -189,23 +198,84 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
}
if (with.shading) {
+ fill <- gensvm.fill.colors()
# bottom left
polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, -1/sqrt(3), -xmin/sqrt(3) -
- sqrt(4/3)), col=fill.green, border=NA)
+ sqrt(4/3)), col=fill$green, border=NA)
# bottom right
polygon(c(1, xmax, xmax, 1), c(ymin, ymin, xmax/sqrt(3) - sqrt(4/3),
- -1/sqrt(3)), col=fill.blue, border=NA)
+ -1/sqrt(3)), col=fill$blue, border=NA)
# top
polygon(c(xmin, 0, xmax, xmax, xmin),
c(-xmin/sqrt(3) + sqrt(4/3), sqrt(4/3), xmax/sqrt(3) + sqrt(4/3),
- ymax, ymax), col=fill.orange,
+ ymax, ymax), col=fill$orange,
border=NA)
}
if (with.legend) {
offset <- abs(xmax - xmin) * 0.05
+ colors <- gensvm.plot.colors()
+ markers <- gensvm.plot.markers()
legend(xmax + offset, ymax, classes, col=colors, pch=markers, xpd=T)
}
+}
- invisible(fit)
+gensvm.plot.1d <- function(classes, with.margins, with.shading, with.legend,
+ center.plot)
+{
+ limits <- par("usr")
+ xmin <- limits[1]
+ xmax <- limits[2]
+ ymin <- limits[3]
+ ymax <- limits[4]
+
+ # draw the fixed boundaries
+ segments(0, ymin, 0, ymax)
+
+ if (with.margins) {
+ segments(-1, ymin, -1, ymax, lty=2)
+ segments(1, ymin, 1, ymax, lty=2)
+ }
+ if (with.shading) {
+ fill <- gensvm.fill.colors()
+ polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, ymax, ymax),
+ col=fill$blue, border=NA)
+ polygon(c(1, xmax, xmax, 1), c(ymin, ymin, ymax, ymax), col=fill$orange,
+ border=NA)
+ }
+ if (with.legend) {
+ offset <- abs(xmax - xmin) * 0.05
+ colors <- gensvm.plot.colors(2)
+ markers <- gensvm.plot.markers(2)
+ legend(xmax + offset, ymax, classes, col=colors, pch=markers, xpd=T)
+ }
+}
+
+gensvm.plot.colors <- function(K=3)
+{
+ point.blue <- rgb(31, 119, 180, maxColorValue=255)
+ point.orange <- rgb(255, 127, 14, maxColorValue=255)
+ point.green <- rgb(44, 160, 44, maxColorValue=255)
+
+ if (K == 3)
+ colors <- as.matrix(c(point.green, point.blue, point.orange))
+ else
+ colors <- as.matrix(c(point.blue, point.orange))
+ return(colors)
+}
+
+gensvm.fill.colors <- function()
+{
+ fill.blue <- rgb(31, 119, 180, 51, maxColorValue=255)
+ fill.orange <- rgb(255, 127, 14, 51, maxColorValue=255)
+ fill.green <- rgb(44, 160, 44, 51, maxColorValue=255)
+
+ fills <- list(blue=fill.blue, orange=fill.orange, green=fill.green)
+ return(fills)
+}
+
+gensvm.plot.markers <- function(K=3)
+{
+ markers <- as.vector(c(15, 16, 17))
+ return(markers[1:K])
}