aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-28 13:41:14 +0100
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-03-28 13:41:14 +0100
commit75d6ed3e5d919b4e5b7bd1e81283131c92530d26 (patch)
tree907f9de5769412c25e61d0c67a50089e6d8e1d6e
parentMake sure we use "weights" everywhere instead of "weight" (diff)
downloadrgensvm-75d6ed3e5d919b4e5b7bd1e81283131c92530d26.tar.gz
rgensvm-75d6ed3e5d919b4e5b7bd1e81283131c92530d26.zip
Validate params in gensvm() function
Parameter validation was only done for some parameters in the gensvm() function and for the parameter grid in gensvm.grid() With this commit the parameters will be tested properly for both functions.
-rw-r--r--NAMESPACE2
-rw-r--r--R/gensvm.R15
-rw-r--r--R/gensvm.grid.R35
-rw-r--r--R/validate.R70
-rw-r--r--man/gensvm.validate.param.grid.Rd10
-rw-r--r--man/gensvm.validate.params.Rd11
6 files changed, 100 insertions, 43 deletions
diff --git a/NAMESPACE b/NAMESPACE
index a70ecfd..6d37b12 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -17,4 +17,6 @@ export(gensvm.load.tiny.grid)
export(gensvm.maxabs.scale)
export(gensvm.refit)
export(gensvm.train.test.split)
+export(gensvm.validate.param.grid)
+export(gensvm.validate.params)
useDynLib(gensvm_wrapper, .registration = TRUE)
diff --git a/R/gensvm.R b/R/gensvm.R
index 32ab61b..cc30b4b 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -133,21 +133,16 @@ gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
if (gamma == 'auto')
gamma <- 1.0/n.features
+ if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda,
+ epsilon=epsilon, gamma=gamma, weights=weights,
+ kernel=kernel))
+ return(NULL)
+
# Convert weights to index
weight.idx <- which(c("unit", "group") == weights)
- if (length(weight.idx) == 0) {
- cat("Error: Incorrect weight specification. ",
- "Valid options are 'unit' and 'group'")
- return
- }
# Convert kernel to index (remember off-by-one for R vs. C)
kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
- if (length(kernel.idx) == 0) {
- cat("Error: Incorrect kernel specification. ",
- "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'")
- return
- }
seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V)
seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V)
diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R
index 2acdc2f..f954b79 100644
--- a/R/gensvm.grid.R
+++ b/R/gensvm.grid.R
@@ -162,7 +162,8 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3,
}
# Validate the range of the values for the gridsearch
- gensvm.validate.param.grid(param.grid)
+ if (!gensvm.validate.param.grid(param.grid))
+ return(NULL)
# Sort the parameter grid for efficient warm starts
param.grid <- gensvm.sort.param.grid(param.grid)
@@ -379,38 +380,6 @@ gensvm.generate.cv.idx <- function(n, folds)
return(cv.idx)
}
-gensvm.validate.param.grid <- function(df)
-{
- expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight",
- "kappa", "lambda", "p", "epsilon", "max.iter")
- for (name in colnames(df)) {
- if (!(name %in% expected.colnames)) {
- stop("Invalid header name supplied in parameter grid: ", name)
- }
- }
-
- conditions <- list(
- p=function(x) { x >= 1.0 && x <= 2.0 },
- kappa=function(x) { x > -1.0 },
- lambda=function(x) {x > 0.0 },
- epsilon=function(x) { x > 0.0 },
- gamma=function(x) { x != 0.0 },
- weight=function(x) { x %in% c("unit", "group") },
- kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") }
- )
-
- for (idx in 1:nrow(df)) {
- for (param in colnames(df)) {
- if (!(param %in% names(conditions)))
- next
- func <- conditions[[param]]
- value <- df[[param]][idx]
- if (!func(value))
- stop("Invalid value in grid for parameter: ", param)
- }
- }
-}
-
gensvm.cv.results <- function(results, param.grid, cv.idx, y.true,
scoring, return.train.score=TRUE)
{
diff --git a/R/validate.R b/R/validate.R
new file mode 100644
index 0000000..b0f3f39
--- /dev/null
+++ b/R/validate.R
@@ -0,0 +1,70 @@
+#' @title [internal] Validate parameters
+#'
+#' @export
+#' @keywords internal
+gensvm.validate.params <- function(p=NULL, kappa=NULL, lambda=NULL,
+ epsilon=NULL, gamma=NULL, weights=NULL,
+ kernel=NULL, ...)
+{
+ the.args <- as.list(match.call())
+ conditions <- gensvm.param.conditions()
+ for (param in names(the.args)) {
+ if (is.null(the.args[[param]]))
+ next
+ if (!(param %in% names(conditions)))
+ next
+ func <- conditions[[param]]
+ value <- eval(the.args[[param]])
+ if (!func(value)) {
+ cat(sprintf("Error: Parameter '%s' got invalid value: %s\n", param,
+ toString(value)))
+ return(FALSE)
+ }
+ }
+ return(TRUE)
+}
+
+#' @title [internal] Validate parameter grid
+#'
+#' @export
+#' @keywords internal
+gensvm.validate.param.grid <- function(df)
+{
+ expected.colnames <- c("kernel", "coef", "degree", "gamma", "weights",
+ "kappa", "lambda", "p", "epsilon", "max.iter")
+ for (name in colnames(df)) {
+ if (!(name %in% expected.colnames)) {
+ cat(sprintf("Error: Invalid name supplied in parameter grid: %s\n",
+ name))
+ return(FALSE)
+ }
+ }
+
+ conditions <- gensvm.param.conditions()
+ for (idx in 1:nrow(df)) {
+ for (param in colnames(df)) {
+ if (!(param %in% names(conditions)))
+ next
+ func <- conditions[[param]]
+ value <- df[[param]][idx]
+ if (!func(value)) {
+ cat(sprintf("Invalid value in grid for parameter: %s\n", param))
+ return(FALSE)
+ }
+ }
+ }
+ return(TRUE)
+}
+
+gensvm.param.conditions <- function()
+{
+ conditions <- list(
+ p=function(x) { x >= 1.0 && x <= 2.0 },
+ kappa=function(x) { x > -1.0 },
+ lambda=function(x) {x > 0.0 },
+ epsilon=function(x) { x > 0.0 },
+ gamma=function(x) { x != 0.0 },
+ weights=function(x) { x %in% c("unit", "group") },
+ kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") }
+ )
+}
diff --git a/man/gensvm.validate.param.grid.Rd b/man/gensvm.validate.param.grid.Rd
new file mode 100644
index 0000000..5528f16
--- /dev/null
+++ b/man/gensvm.validate.param.grid.Rd
@@ -0,0 +1,10 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/validate.R
+\name{gensvm.validate.param.grid}
+\alias{gensvm.validate.param.grid}
+\title{[internal] Validate parameter grid}
+\usage{
+gensvm.validate.param.grid(df)
+}
+\keyword{internal}
+
diff --git a/man/gensvm.validate.params.Rd b/man/gensvm.validate.params.Rd
new file mode 100644
index 0000000..1d249a9
--- /dev/null
+++ b/man/gensvm.validate.params.Rd
@@ -0,0 +1,11 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/validate.R
+\name{gensvm.validate.params}
+\alias{gensvm.validate.params}
+\title{[internal] Validate parameters}
+\usage{
+gensvm.validate.params(p = NULL, kappa = NULL, lambda = NULL,
+ epsilon = NULL, gamma = NULL, weights = NULL, kernel = NULL, ...)
+}
+\keyword{internal}
+