1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
|
#' @title Plot the simplex space of the fitted GenSVM model
#'
#' @description This function creates a plot of the simplex space for a fitted
#' GenSVM model and the given data set. This function works for datasets with
#' two or three classes. For more than 3 classes, the simplex space is too high
#' dimensional to easily visualize.
#'
#' @param x A fitted \code{gensvm} object
#' @param labels the labels to color points with. If this is omitted the
#' predicted labels are used. When given, it must contain one label for every
#' row of the plotted data.
#' @param newdata the dataset to plot. If this is NULL the training data is
#' used.
#' @param with.margins plot the margins
#' @param with.shading show shaded areas for the class regions
#' @param with.legend show the legend for the class labels
#' @param center.plot ensure that the boundaries and margins are always visible
#' in the plot. This only affects axis limits that are not set explicitly
#' through \code{xlim} or \code{ylim}.
#' @param xlim allows the user to force certain plot limits. If set, these
#' bounds will be used for the horizontal axis.
#' @param ylim allows the user to force certain plot limits. If set, these
#' bounds will be used for the vertical axis.
#' @param ... further arguments are passed to the builtin plot() function
#'
#' @return returns the object passed as input, invisibly. If the model does
#' not have 2 or 3 classes, or if the data or labels do not match the model,
#' an error message is printed and \code{NULL} is returned invisibly.
#'
#' @author
#' Gerrit J.J. van den Burg, Patrick J.F. Groenen \cr
#' Maintainer: Gerrit J.J. van den Burg <gertjanvandenburg@gmail.com>
#'
#' @references
#' Van den Burg, G.J.J. and Groenen, P.J.F. (2016). \emph{GenSVM: A Generalized
#' Multiclass Support Vector Machine}, Journal of Machine Learning Research,
#' 17(225):1--42. URL \url{http://jmlr.org/papers/v17/14-526.html}.
#'
#' @seealso
#' \code{\link{plot.gensvm.grid}}, \code{\link{predict.gensvm}},
#' \code{\link{gensvm}}, \code{\link{gensvm-package}}
#'
#' @method plot gensvm
#'
#' @export
#'
#' @importFrom grDevices rgb
#' @importFrom graphics legend par plot polygon segments
#'
#' @examples
#' x <- iris[, -5]
#' y <- iris[, 5]
#'
#' # train the model
#' fit <- gensvm(x, y)
#'
#' # plot the simplex space
#' plot(fit)
#'
#' # plot and use the true colors (easier to spot misclassified samples)
#' plot(fit, y)
#'
#' # plot only misclassified samples
#' x.mis <- x[predict(fit) != y, ]
#' y.mis.true <- y[predict(fit) != y]
#' plot(fit, newdata=x.mis)
#' plot(fit, y.mis.true, newdata=x.mis)
#'
#' # plot a 2-d model
#' xx <- x[y %in% c('versicolor', 'virginica'), ]
#' yy <- y[y %in% c('versicolor', 'virginica')]
#' fit <- gensvm(xx, yy, kernel='rbf', max.iter=1000)
#' plot(fit)
#'
plot.gensvm <- function(x, labels, newdata=NULL, with.margins=TRUE,
                        with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
                        xlim=NULL, ylim=NULL, ...)
{
    fit <- x
    if (!(fit$n.classes %in% c(2, 3))) {
        cat("Error: Can only plot with 2 or 3 classes\n")
        return(invisible(NULL))
    }

    # The training data is recovered by re-evaluating the `x` argument of the
    # original gensvm() call in the caller's frame. Do this at most once, and
    # only when it is needed: as the default plot data, or (below) as the
    # kernel basis for a nonlinear model.
    x.train <- NULL
    if (is.null(newdata)) {
        x.train <- eval.parent(fit$call$x)
        newdata <- x.train
    }

    # Sanity checks
    if (ncol(newdata) != fit$n.features) {
        cat("Error: Number of features of fitted model and testing data",
            "disagree.\n")
        return(invisible(NULL))
    }
    # Mismatched labels would otherwise be silently recycled into wrong colors
    if (!missing(labels) && length(labels) != nrow(newdata)) {
        cat("Error: Number of labels and number of data points disagree.\n")
        return(invisible(NULL))
    }

    # Compute the simplex-space coordinates S and the predicted class indices
    if (fit$kernel == 'linear') {
        V <- coef(fit)
        Z <- cbind(matrix(1, dim(newdata)[1], 1), as.matrix(newdata))
        S <- Z %*% V
        y.pred <- predict(fit, newdata)
        y.pred.int <- match(y.pred, fit$classes)
    } else {
        if (is.null(x.train))
            x.train <- eval.parent(fit$call$x)
        kernels <- c("linear", "poly", "rbf", "sigmoid")
        # zero-based kernel index passed to the compiled routine
        kernel.idx <- which(kernels == fit$kernel) - 1
        plotdata <- .Call("R_gensvm_plotdata_kernels",
                          as.matrix(newdata),
                          as.matrix(x.train),
                          as.matrix(fit$V),
                          as.integer(nrow(fit$V)),
                          as.integer(ncol(fit$V)),
                          as.integer(nrow(x.train)),
                          as.integer(nrow(newdata)),
                          as.integer(fit$n.features),
                          as.integer(fit$n.classes),
                          as.integer(kernel.idx),
                          fit$gamma,
                          fit$coef,
                          fit$degree,
                          fit$kernel.eigen.cutoff
                          )
        S <- plotdata$ZV
        y.pred.int <- plotdata$y.pred
    }

    colors <- gensvm.plot.colors(fit$n.classes)
    markers <- gensvm.plot.markers(fit$n.classes)
    indices <- if (missing(labels)) y.pred.int else match(labels, fit$classes)
    col.vector <- colors[indices]
    mark.vector <- markers[indices]

    # With 2 classes the simplex is one-dimensional; add a zero second
    # coordinate so the points are drawn on the horizontal axis.
    if (fit$n.classes == 2)
        S <- cbind(S, matrix(0, nrow(S), 1))

    # NOTE(review): par() is left modified on exit, as before; restoring it
    # here could affect overlays drawn after this call — confirm before
    # changing.
    par(pty="s")
    if (center.plot) {
        # Widen the axes so the boundaries and margins are always visible,
        # unless the user forced the limits explicitly.
        if (is.null(xlim))
            xlim <- c(min(S[, 1], -1.2), max(S[, 1], 1.2))
        if (is.null(ylim))
            ylim <- c(min(S[, 2], -0.75), max(S[, 2], 1.2))
    }
    plot(S[, 1], S[, 2], col=col.vector, pch=mark.vector, ylab='', xlab='',
         asp=1, xlim=xlim, ylim=ylim, ...)

    if (fit$n.classes == 3)
        gensvm.plot.2d(fit$classes, with.margins, with.shading, with.legend,
                       center.plot)
    else
        gensvm.plot.1d(fit$classes, with.margins, with.shading, with.legend,
                       center.plot)
    invisible(fit)
}
# Decorate a 3-class simplex plot with the fixed decision boundaries, the
# margins, shaded class regions and a legend. Must be called after plot() so
# that par("usr") holds the current axis limits. The shading colors (green
# bottom-left, blue bottom-right, orange top) follow the point colors of
# gensvm.plot.colors(3). center.plot is accepted for symmetry with
# gensvm.plot.1d but is not used here.
gensvm.plot.2d <- function(classes, with.margins, with.shading,
                           with.legend, center.plot)
{
    limits <- par("usr")
    xmin <- limits[1]
    xmax <- limits[2]
    ymin <- limits[3]
    ymax <- limits[4]
    r3 <- sqrt(3)
    # vertical offset between a sloped boundary and its parallel margin line
    h <- sqrt(4/3)

    # draw the fixed boundaries: straight down from the origin, and the lines
    # y = x/sqrt(3) (right) and y = -x/sqrt(3) (left) rising from the origin.
    # Use -xmin (not abs(xmin)) so the left line keeps the correct slope even
    # when a user-supplied xlim does not contain zero.
    segments(0, 0, 0, ymin)
    segments(0, 0, xmax, xmax/r3)
    segments(xmin, -xmin/r3, 0, 0)

    if (with.margins) {
        # margin from left below decision boundary to center
        segments(xmin, -xmin/r3 - h, -1, -1/r3, lty=2)
        # margin from left center to down
        segments(-1, -1/r3, -1, ymin, lty=2)
        # margin from right center to middle
        segments(1, -1/r3, 1, ymin, lty=2)
        # margin from right center to right boundary
        segments(1, -1/r3, xmax, xmax/r3 - h, lty=2)
        # margin from center to top left
        segments(xmin, -xmin/r3 + h, 0, h, lty=2)
        # margin from center to top right
        segments(0, h, xmax, xmax/r3 + h, lty=2)
    }

    if (with.shading) {
        fill <- gensvm.fill.colors()
        # bottom left: beyond the left/lower margins
        polygon(c(xmin, -1, -1, xmin),
                c(ymin, ymin, -1/r3, -xmin/r3 - h),
                col=fill$green, border=NA)
        # bottom right: beyond the right/lower margins
        polygon(c(1, xmax, xmax, 1),
                c(ymin, ymin, xmax/r3 - h, -1/r3),
                col=fill$blue, border=NA)
        # top: above both upper margins
        polygon(c(xmin, 0, xmax, xmax, xmin),
                c(-xmin/r3 + h, h, xmax/r3 + h, ymax, ymax),
                col=fill$orange, border=NA)
    }

    if (with.legend) {
        # place the legend just right of the plot region (xpd allows drawing
        # outside it)
        offset <- abs(xmax - xmin) * 0.05
        colors <- gensvm.plot.colors()
        markers <- gensvm.plot.markers()
        legend(xmax + offset, ymax, classes, col=colors, pch=markers,
               xpd=TRUE)
    }
}
# Decorate a 2-class simplex plot (points lie on the horizontal axis) with the
# decision boundary at x = 0, the margins at x = -1 and x = 1, shaded class
# regions beyond the margins, and a legend. Must be called after plot() so
# that par("usr") holds the current axis limits. center.plot is accepted for
# symmetry with gensvm.plot.2d but is not used here.
gensvm.plot.1d <- function(classes, with.margins, with.shading, with.legend,
                           center.plot)
{
    limits <- par("usr")
    xmin <- limits[1]
    xmax <- limits[2]
    ymin <- limits[3]
    ymax <- limits[4]

    # draw the fixed boundary
    segments(0, ymin, 0, ymax)

    if (with.margins) {
        segments(-1, ymin, -1, ymax, lty=2)
        segments(1, ymin, 1, ymax, lty=2)
    }

    if (with.shading) {
        fill <- gensvm.fill.colors()
        # left of the left margin (blue) and right of the right margin
        # (orange), matching the point colors of gensvm.plot.colors(2)
        polygon(c(xmin, -1, -1, xmin), c(ymin, ymin, ymax, ymax),
                col=fill$blue, border=NA)
        polygon(c(1, xmax, xmax, 1), c(ymin, ymin, ymax, ymax),
                col=fill$orange, border=NA)
    }

    if (with.legend) {
        # place the legend just right of the plot region (xpd allows drawing
        # outside it)
        offset <- abs(xmax - xmin) * 0.05
        colors <- gensvm.plot.colors(2)
        markers <- gensvm.plot.markers(2)
        legend(xmax + offset, ymax, classes, col=colors, pch=markers,
               xpd=TRUE)
    }
}
# Point colors per class, returned as a one-column character matrix. For
# K == 3 the order is green, blue, orange; for any other K the two-color
# palette (blue, orange) is returned.
gensvm.plot.colors <- function(K=3)
{
    palette <- c(
        blue = rgb(31, 119, 180, maxColorValue = 255),
        orange = rgb(255, 127, 14, maxColorValue = 255),
        green = rgb(44, 160, 44, maxColorValue = 255)
    )
    picked <- if (K == 3) c("green", "blue", "orange") else c("blue", "orange")
    # drop the lookup names so the matrix carries no row names
    as.matrix(unname(palette[picked]))
}
# Semi-transparent fill colors (alpha 51 of 255) used to shade the class
# regions; a list with elements blue, orange and green.
gensvm.fill.colors <- function()
{
    shade <- function(r, g, b) rgb(r, g, b, 51, maxColorValue=255)
    list(blue=shade(31, 119, 180),
         orange=shade(255, 127, 14),
         green=shade(44, 160, 44))
}
# Plotting symbols (pch codes) for the first K classes: 15 is a filled square,
# 16 a filled circle and 17 a filled triangle. Selection uses 1:K, so K > 3
# yields trailing NA entries.
gensvm.plot.markers <- function(K=3)
{
    pch.codes <- unname(c(square=15, circle=16, triangle=17))
    pch.codes[1:K]
}
|