diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-30 22:07:11 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-30 22:07:11 +0100 |
| commit | 9cbb676219df92e5600a1a19fe8d63ced10e1b28 (patch) | |
| tree | dc47284c86d7d93f71b1c888c724da98295525b0 /R | |
| parent | Return invisibly on error (diff) | |
| download | rgensvm-9cbb676219df92e5600a1a19fe8d63ced10e1b28.tar.gz rgensvm-9cbb676219df92e5600a1a19fe8d63ced10e1b28.zip | |
Fixes to get the input data from the call
Diffstat (limited to 'R')
| -rw-r--r-- | R/plot.gensvm.R | 66 | ||||
| -rw-r--r-- | R/predict.gensvm.R | 33 |
2 files changed, 41 insertions, 58 deletions
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R index ac40ccf..d0c1eb2 100644 --- a/R/plot.gensvm.R +++ b/R/plot.gensvm.R @@ -6,10 +6,9 @@ #' dimensional to easily visualize. #' #' @param fit A fitted \code{gensvm} object -#' @param x the dataset to plot -#' @param y.true the true data labels. If provided the objects will be colored -#' using the true labels instead of the predicted labels. This makes it easy to -#' identify misclassified objects. +#' @param x the dataset to plot (if NULL the training data is used) +#' @param y the labels to color points with (if NULL the predicted labels are +#' used) #' @param with.margins plot the margins #' @param with.shading show shaded areas for the class regions #' @param with.legend show the legend for the class labels @@ -48,19 +47,19 @@ #' fit <- gensvm(x, y) #' #' # plot the simplex space -#' plot(fit, x) +#' plot(fit) #' #' # plot and use the true colors (easier to spot misclassified samples) -#' plot(fit, x, y.true=y) +#' plot(fit, y) #' #' # plot only misclassified samples -#' x.mis <- x[predict(fit, x) != y, ] -#' y.mis.true <- y[predict(fit, x) != y] -#' plot(fit, x.mis) -#' plot(fit, x.mis, y.true=y.mis.true) +#' x.mis <- x[predict(fit) != y, ] +#' y.mis.true <- y[predict(fit) != y] +#' plot(fit, x.test=x.mis) +#' plot(fit, y.mis.true, x.test=x.mis) #' -plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE, - with.shading=TRUE, with.legend=TRUE, center.plot=TRUE, +plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, + with.shading=TRUE, with.legend=TRUE, center.plot=TRUE, xlim=NULL, ylim=NULL, ...) { if (!(fit$n.classes %in% c(2,3))) { @@ -68,44 +67,32 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE, return(invisible(NULL)) } + x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test + # Sanity check - if (ncol(x) != fit$n.features) { + if (ncol(x.test) != fit$n.features) { cat("Error: Number of features of fitted model and testing data disagree.\n") - invisible(NULL) - } - - x.train <- fit$X.train - if (fit$kernel != 'linear' && is.null(x.train)) { - cat("Error: The training data is needed to plot data for ", - "nonlinear GenSVM. This data is not present in the fitted ", - "model!\n", sep="") - invisible(NULL) - } - if (!is.null(x.train) && ncol(x.train) != fit$n.features) { - cat("Error: Number of features of fitted model and training data ", - "disagree.\n", sep="") - invisible(NULL) + return(invisible(NULL)) } - x <- as.matrix(x) - + x.train <- eval.parent(fit$call$x) if (fit$kernel == 'linear') { V <- coef(fit) - Z <- cbind(matrix(1, dim(x)[1], 1), x) + Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test)) S <- Z %*% V - y.pred.orig <- predict(fit, x) + y.pred.orig <- predict(fit, x.test) } else { kernels <- c("linear", "poly", "rbf", "sigmoid") kernel.idx <- which(kernels == fit$kernel) - 1 plotdata <- .Call("R_gensvm_plotdata_kernels", - as.matrix(x), + as.matrix(x.test), as.matrix(x.train), as.matrix(fit$V), as.integer(nrow(fit$V)), as.integer(ncol(fit$V)), as.integer(nrow(x.train)), - as.integer(nrow(x)), + as.integer(nrow(x.test)), as.integer(fit$n.features), as.integer(fit$n.classes), as.integer(kernel.idx), @@ -125,16 +112,13 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE, y.pred <- y.pred.orig } - labels <- if(is.null(y.true)) y.pred else y.true - classes <- unique(labels) - - colors <- gensvm.plot.colors(length(classes)) - markers <- gensvm.plot.markers(length(classes)) + labels <- if(missing(y)) y.pred else match(y, classes) - indices <- match(labels, classes) + colors <- gensvm.plot.colors(fit$n.classes) + markers <- gensvm.plot.markers(fit$n.classes) - col.vector <- colors[indices] - mark.vector <- markers[indices] + col.vector <- colors[labels] + mark.vector <- markers[labels] if (fit$n.classes == 2) S <- cbind(S, matrix(0, nrow(S), 1)) diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 558f80b..2cf7401 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -4,8 +4,9 @@ #' fitted GenSVM model. #' #' @param fit Fitted \code{gensvm} object -#' @param x.test Matrix of new values for \code{x} for which predictions need -#' to be made. +#' @param newdata Matrix of new data for which predictions need to be made. +#' @param add.rownames add the rownames from the training data to the +#' predictions #' @param \dots further arguments are ignored #' #' @return a vector of class labels, with the same type as the original class @@ -41,7 +42,7 @@ #' # compute the accuracy with gensvm.accuracy #' gensvm.accuracy(y.test, y.test.pred) #' -predict.gensvm <- function(fit, x.test, ...) +predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...) { ## Implementation note: ## - It might seem that it would be faster to do the prediction directly in @@ -50,26 +51,19 @@ predict.gensvm <- function(fit, x.test, ...) ## However, if you actually implement it and compare, we find that calling ## the C implementation is *much* faster than doing it in R. + if (missing(newdata)) { + newdata <- eval.parent(fit$call$x) + } + x.test <- as.matrix(newdata) + # Sanity check if (ncol(x.test) != fit$n.features) { cat("Error: Number of features of fitted model and testing", "data disagree.\n") - invisible(NULL) - } - - x.train <- fit$X.train - if (fit$kernel != 'linear' && is.null(x.train)) { - cat("Error: The training data is needed to compute predictions for ", - "nonlinear GenSVM. This data is not present in the fitted ", - "model!\n", sep="") - invisible(NULL) - } - if (!is.null(x.train) && ncol(x.train) != fit$n.features) { - cat("Error: Number of features of fitted model and training", - "data disagree.\n") - invisible(NULL) + return(invisible(NULL)) } + x.train <- eval.parent(fit$call$x) if (fit$kernel == 'linear') { y.pred.c <- .Call("R_gensvm_predict", as.matrix(x.test), @@ -101,5 +95,10 @@ predict.gensvm <- function(fit, x.test, ...) yhat <- fit$classes[y.pred.c] + if (add.rownames) { + yhat <- matrix(yhat, length(yhat), 1) + rownames(yhat) <- rownames(x.train) + } + return(yhat) } |
