diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-30 17:08:49 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-30 17:08:49 +0100 |
| commit | 937772ceb77f3213119d736da8e2fad620d16add (patch) | |
| tree | 2029365faa0e74d4c85e4c51d95548e145a67bfe | |
| parent | pass further arguments to plot function (diff) | |
| download | rgensvm-937772ceb77f3213119d736da8e2fad620d16add.tar.gz rgensvm-937772ceb77f3213119d736da8e2fad620d16add.zip | |
Update some return values
| -rw-r--r-- | R/gensvm.accuracy.R | 2 | ||||
| -rw-r--r-- | R/gensvm.grid.R | 5 | ||||
| -rw-r--r-- | R/plot.gensvm.R | 12 | ||||
| -rw-r--r-- | R/predict.gensvm.R | 5 |
4 files changed, 15 insertions, 9 deletions
diff --git a/R/gensvm.accuracy.R b/R/gensvm.accuracy.R index 9a60411..c762dd7 100644 --- a/R/gensvm.accuracy.R +++ b/R/gensvm.accuracy.R @@ -31,7 +31,7 @@ gensvm.accuracy <- function(y.true, y.pred) if (n != length(y.pred)) { cat("Error: Can't compute accuracy if vector don't have the ", "same length\n") - return + return(-1) } return (sum(y.true == y.pred) / n) diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R index 613b718..6db768c 100644 --- a/R/gensvm.grid.R +++ b/R/gensvm.grid.R @@ -156,6 +156,11 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3, n.features <- ncol(X) n.classes <- length(unique(y)) + if (n.objects != length(y)) { + cat("Error: X and y are not the same length.\n") + return(NULL) + } + if (is.character(param.grid)) { if (param.grid == 'tiny') { param.grid <- gensvm.load.tiny.grid() diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R index 5bea7eb..ad69597 100644 --- a/R/plot.gensvm.R +++ b/R/plot.gensvm.R @@ -63,16 +63,16 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE, with.shading=TRUE, with.legend=TRUE, center.plot=TRUE, xlim=NULL, ylim=NULL, ...) { - if (fit$n.classes != 3) { - cat("Error: Can only plot with 3 classes\n") - return + if (!(fit$n.classes %in% c(2,3))) { + cat("Error: Can only plot with 2 or 3 classes\n") + return(NULL) } # Sanity check if (ncol(x) != fit$n.features) { cat("Error: Number of features of fitted model and testing data disagree.\n") - return + return(NULL) } x.train <- fit$X.train @@ -80,12 +80,12 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE, cat("Error: The training data is needed to plot data for ", "nonlinear GenSVM. This data is not present in the fitted ", "model!\n", sep="") - return + return(NULL) } if (!is.null(x.train) && ncol(x.train) != fit$n.features) { cat("Error: Number of features of fitted model and training data ", "disagree.\n", sep="") - return + return(NULL) } x <- as.matrix(x) diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index d7540f6..133b40e 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -54,7 +54,7 @@ predict.gensvm <- function(fit, x.test, ...) if (ncol(x.test) != fit$n.features) { cat("Error: Number of features of fitted model and testing", "data disagree.\n") - return + return(NULL) } x.train <- fit$X.train @@ -62,11 +62,12 @@ predict.gensvm <- function(fit, x.test, ...) cat("Error: The training data is needed to compute predictions for ", "nonlinear GenSVM. This data is not present in the fitted ", "model!\n", sep="") + return(NULL) } if (!is.null(x.train) && ncol(x.train) != fit$n.features) { cat("Error: Number of features of fitted model and training", "data disagree.\n") - return + return(NULL) } if (fit$kernel == 'linear') { |
