aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Gertjan van den Burg <gertjanvandenburg@gmail.com> 2018-03-30 22:07:11 +0100
committer Gertjan van den Burg <gertjanvandenburg@gmail.com> 2018-03-30 22:07:11 +0100
commit 9cbb676219df92e5600a1a19fe8d63ced10e1b28 (patch)
tree dc47284c86d7d93f71b1c888c724da98295525b0
parent Return invisibly on error (diff)
download rgensvm-9cbb676219df92e5600a1a19fe8d63ced10e1b28.tar.gz
rgensvm-9cbb676219df92e5600a1a19fe8d63ced10e1b28.zip
Fixes to get the input data from the call
-rw-r--r-- R/plot.gensvm.R 66
-rw-r--r-- R/predict.gensvm.R 33
2 files changed, 41 insertions, 58 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
index ac40ccf..d0c1eb2 100644
--- a/R/plot.gensvm.R
+++ b/R/plot.gensvm.R
@@ -6,10 +6,9 @@
#' dimensional to easily visualize.
#'
#' @param fit A fitted \code{gensvm} object
-#' @param x the dataset to plot
-#' @param y.true the true data labels. If provided the objects will be colored
-#' using the true labels instead of the predicted labels. This makes it easy to
-#' identify misclassified objects.
+#' @param x the dataset to plot (if NULL the training data is used)
+#' @param y the labels to color points with (if NULL the predicted labels are
+#' used)
#' @param with.margins plot the margins
#' @param with.shading show shaded areas for the class regions
#' @param with.legend show the legend for the class labels
@@ -48,19 +47,19 @@
#' fit <- gensvm(x, y)
#'
#' # plot the simplex space
-#' plot(fit, x)
+#' plot(fit)
#'
#' # plot and use the true colors (easier to spot misclassified samples)
-#' plot(fit, x, y.true=y)
+#' plot(fit, y)
#'
#' # plot only misclassified samples
-#' x.mis <- x[predict(fit, x) != y, ]
-#' y.mis.true <- y[predict(fit, x) != y]
-#' plot(fit, x.mis)
-#' plot(fit, x.mis, y.true=y.mis.true)
+#' x.mis <- x[predict(fit) != y, ]
+#' y.mis.true <- y[predict(fit) != y]
+#' plot(fit, x.test=x.mis)
+#' plot(fit, y.mis.true, x.test=x.mis)
#'
-plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
- with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
+plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
+ with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
xlim=NULL, ylim=NULL, ...)
{
if (!(fit$n.classes %in% c(2,3))) {
@@ -68,44 +67,32 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
return(invisible(NULL))
}
+ x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test
+
# Sanity check
- if (ncol(x) != fit$n.features) {
+ if (ncol(x.test) != fit$n.features) {
cat("Error: Number of features of fitted model and testing data
disagree.\n")
- invisible(NULL)
- }
-
- x.train <- fit$X.train
- if (fit$kernel != 'linear' && is.null(x.train)) {
- cat("Error: The training data is needed to plot data for ",
- "nonlinear GenSVM. This data is not present in the fitted ",
- "model!\n", sep="")
- invisible(NULL)
- }
- if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
- cat("Error: Number of features of fitted model and training data ",
- "disagree.\n", sep="")
- invisible(NULL)
+ return(invisible(NULL))
}
- x <- as.matrix(x)
-
+ x.train <- eval.parent(fit$call$x)
if (fit$kernel == 'linear') {
V <- coef(fit)
- Z <- cbind(matrix(1, dim(x)[1], 1), x)
+ Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test))
S <- Z %*% V
- y.pred.orig <- predict(fit, x)
+ y.pred.orig <- predict(fit, x.test)
} else {
kernels <- c("linear", "poly", "rbf", "sigmoid")
kernel.idx <- which(kernels == fit$kernel) - 1
plotdata <- .Call("R_gensvm_plotdata_kernels",
- as.matrix(x),
+ as.matrix(x.test),
as.matrix(x.train),
as.matrix(fit$V),
as.integer(nrow(fit$V)),
as.integer(ncol(fit$V)),
as.integer(nrow(x.train)),
- as.integer(nrow(x)),
+ as.integer(nrow(x.test)),
as.integer(fit$n.features),
as.integer(fit$n.classes),
as.integer(kernel.idx),
@@ -125,16 +112,13 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
y.pred <- y.pred.orig
}
- labels <- if(is.null(y.true)) y.pred else y.true
- classes <- unique(labels)
-
- colors <- gensvm.plot.colors(length(classes))
- markers <- gensvm.plot.markers(length(classes))
+ labels <- if(missing(y)) y.pred else match(y, classes)
- indices <- match(labels, classes)
+ colors <- gensvm.plot.colors(fit$n.classes)
+ markers <- gensvm.plot.markers(fit$n.classes)
- col.vector <- colors[indices]
- mark.vector <- markers[indices]
+ col.vector <- colors[labels]
+ mark.vector <- markers[labels]
if (fit$n.classes == 2)
S <- cbind(S, matrix(0, nrow(S), 1))
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 558f80b..2cf7401 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -4,8 +4,9 @@
#' fitted GenSVM model.
#'
#' @param fit Fitted \code{gensvm} object
-#' @param x.test Matrix of new values for \code{x} for which predictions need
-#' to be made.
+#' @param newdata Matrix of new data for which predictions need to be made.
+#' @param add.rownames add the rownames from the training data to the
+#' predictions
#' @param \dots further arguments are ignored
#'
#' @return a vector of class labels, with the same type as the original class
@@ -41,7 +42,7 @@
#' # compute the accuracy with gensvm.accuracy
#' gensvm.accuracy(y.test, y.test.pred)
#'
-predict.gensvm <- function(fit, x.test, ...)
+predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...)
{
## Implementation note:
## - It might seem that it would be faster to do the prediction directly in
@@ -50,26 +51,19 @@ predict.gensvm <- function(fit, x.test, ...)
## However, if you actually implement it and compare, we find that calling
## the C implementation is *much* faster than doing it in R.
+ if (missing(newdata)) {
+ newdata <- eval.parent(fit$call$x)
+ }
+ x.test <- as.matrix(newdata)
+
# Sanity check
if (ncol(x.test) != fit$n.features) {
cat("Error: Number of features of fitted model and testing",
"data disagree.\n")
- invisible(NULL)
- }
-
- x.train <- fit$X.train
- if (fit$kernel != 'linear' && is.null(x.train)) {
- cat("Error: The training data is needed to compute predictions for ",
- "nonlinear GenSVM. This data is not present in the fitted ",
- "model!\n", sep="")
- invisible(NULL)
- }
- if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
- cat("Error: Number of features of fitted model and training",
- "data disagree.\n")
- invisible(NULL)
+ return(invisible(NULL))
}
+ x.train <- eval.parent(fit$call$x)
if (fit$kernel == 'linear') {
y.pred.c <- .Call("R_gensvm_predict",
as.matrix(x.test),
@@ -101,5 +95,10 @@ predict.gensvm <- function(fit, x.test, ...)
yhat <- fit$classes[y.pred.c]
+ if (add.rownames) {
+ yhat <- matrix(yhat, length(yhat), 1)
+ rownames(yhat) <- rownames(x.train)
+ }
+
return(yhat)
}