aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-30 17:08:49 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-30 17:08:49 +0100
commit937772ceb77f3213119d736da8e2fad620d16add (patch)
tree2029365faa0e74d4c85e4c51d95548e145a67bfe
parentpass further arguments to plot function (diff)
downloadrgensvm-937772ceb77f3213119d736da8e2fad620d16add.tar.gz
rgensvm-937772ceb77f3213119d736da8e2fad620d16add.zip
Update some return values
-rw-r--r--R/gensvm.accuracy.R2
-rw-r--r--R/gensvm.grid.R5
-rw-r--r--R/plot.gensvm.R12
-rw-r--r--R/predict.gensvm.R5
4 files changed, 15 insertions, 9 deletions
diff --git a/R/gensvm.accuracy.R b/R/gensvm.accuracy.R
index 9a60411..c762dd7 100644
--- a/R/gensvm.accuracy.R
+++ b/R/gensvm.accuracy.R
@@ -31,7 +31,7 @@ gensvm.accuracy <- function(y.true, y.pred)
if (n != length(y.pred)) {
cat("Error: Can't compute accuracy if vector don't have the ",
"same length\n")
- return
+ return(-1)
}
return (sum(y.true == y.pred) / n)
diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R
index 613b718..6db768c 100644
--- a/R/gensvm.grid.R
+++ b/R/gensvm.grid.R
@@ -156,6 +156,11 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3,
n.features <- ncol(X)
n.classes <- length(unique(y))
+ if (n.objects != length(y)) {
+ cat("Error: X and y are not the same length.\n")
+ return(NULL)
+ }
+
if (is.character(param.grid)) {
if (param.grid == 'tiny') {
param.grid <- gensvm.load.tiny.grid()
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
index 5bea7eb..ad69597 100644
--- a/R/plot.gensvm.R
+++ b/R/plot.gensvm.R
@@ -63,16 +63,16 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
xlim=NULL, ylim=NULL, ...)
{
- if (fit$n.classes != 3) {
- cat("Error: Can only plot with 3 classes\n")
- return
+ if (!(fit$n.classes %in% c(2,3))) {
+ cat("Error: Can only plot with 2 or 3 classes\n")
+ return(NULL)
}
# Sanity check
if (ncol(x) != fit$n.features) {
cat("Error: Number of features of fitted model and testing data
disagree.\n")
- return
+ return(NULL)
}
x.train <- fit$X.train
@@ -80,12 +80,12 @@ plot.gensvm <- function(fit, x, y.true=NULL, with.margins=TRUE,
cat("Error: The training data is needed to plot data for ",
"nonlinear GenSVM. This data is not present in the fitted ",
"model!\n", sep="")
- return
+ return(NULL)
}
if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
cat("Error: Number of features of fitted model and training data ",
"disagree.\n", sep="")
- return
+ return(NULL)
}
x <- as.matrix(x)
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index d7540f6..133b40e 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -54,7 +54,7 @@ predict.gensvm <- function(fit, x.test, ...)
if (ncol(x.test) != fit$n.features) {
cat("Error: Number of features of fitted model and testing",
"data disagree.\n")
- return
+ return(NULL)
}
x.train <- fit$X.train
@@ -62,11 +62,12 @@ predict.gensvm <- function(fit, x.test, ...)
cat("Error: The training data is needed to compute predictions for ",
"nonlinear GenSVM. This data is not present in the fitted ",
"model!\n", sep="")
+ return(NULL)
}
if (!is.null(x.train) && ncol(x.train) != fit$n.features) {
cat("Error: Number of features of fitted model and training",
"data disagree.\n")
- return
+ return(NULL)
}
if (fit$kernel == 'linear') {