aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
Diffstat (limited to 'R')
-rw-r--r--R/gensvm.R15
-rw-r--r--R/gensvm.grid.R35
-rw-r--r--R/validate.R70
3 files changed, 77 insertions, 43 deletions
diff --git a/R/gensvm.R b/R/gensvm.R
index 32ab61b..cc30b4b 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -133,21 +133,16 @@ gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6,
if (gamma == 'auto')
gamma <- 1.0/n.features
+ if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda,
+ epsilon=epsilon, gamma=gamma, weights=weights,
+ kernel=kernel))
+ return(NULL)
+
# Convert weights to index
weight.idx <- which(c("unit", "group") == weights)
- if (length(weight.idx) == 0) {
- cat("Error: Incorrect weight specification. ",
- "Valid options are 'unit' and 'group'")
- return
- }
# Convert kernel to index (remember off-by-one for R vs. C)
kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
- if (length(kernel.idx) == 0) {
- cat("Error: Incorrect kernel specification. ",
- "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'")
- return
- }
seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V)
seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V)
diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R
index 2acdc2f..f954b79 100644
--- a/R/gensvm.grid.R
+++ b/R/gensvm.grid.R
@@ -162,7 +162,8 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3,
}
# Validate the range of the values for the gridsearch
- gensvm.validate.param.grid(param.grid)
+ if (!gensvm.validate.param.grid(param.grid))
+ return(NULL)
# Sort the parameter grid for efficient warm starts
param.grid <- gensvm.sort.param.grid(param.grid)
@@ -379,38 +380,6 @@ gensvm.generate.cv.idx <- function(n, folds)
return(cv.idx)
}
-gensvm.validate.param.grid <- function(df)
-{
- expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight",
- "kappa", "lambda", "p", "epsilon", "max.iter")
- for (name in colnames(df)) {
- if (!(name %in% expected.colnames)) {
- stop("Invalid header name supplied in parameter grid: ", name)
- }
- }
-
- conditions <- list(
- p=function(x) { x >= 1.0 && x <= 2.0 },
- kappa=function(x) { x > -1.0 },
- lambda=function(x) {x > 0.0 },
- epsilon=function(x) { x > 0.0 },
- gamma=function(x) { x != 0.0 },
- weight=function(x) { x %in% c("unit", "group") },
- kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") }
- )
-
- for (idx in 1:nrow(df)) {
- for (param in colnames(df)) {
- if (!(param %in% names(conditions)))
- next
- func <- conditions[[param]]
- value <- df[[param]][idx]
- if (!func(value))
- stop("Invalid value in grid for parameter: ", param)
- }
- }
-}
-
gensvm.cv.results <- function(results, param.grid, cv.idx, y.true,
scoring, return.train.score=TRUE)
{
diff --git a/R/validate.R b/R/validate.R
new file mode 100644
index 0000000..b0f3f39
--- /dev/null
+++ b/R/validate.R
@@ -0,0 +1,70 @@
+#' @title [internal] Validate parameters
+#'
+#' @export
+#' @keywords internal
+gensvm.validate.params <- function(p=NULL, kappa=NULL, lambda=NULL,
+ epsilon=NULL, gamma=NULL, weights=NULL,
+ kernel=NULL, ...)
+{
+ the.args <- as.list(match.call())
+ conditions <- gensvm.param.conditions()
+ for (param in names(the.args)) {
+ if (is.null(the.args[[param]]))
+ next
+ if (!(param %in% names(conditions)))
+ next
+ func <- conditions[[param]]
+ value <- eval(the.args[[param]])
+ if (!func(value)) {
+ cat(sprintf("Error: Parameter '%s' got invalid value: %s\n", param,
+ toString(value)))
+ return(FALSE)
+ }
+ }
+ return(TRUE)
+}
+
+#' @title [internal] Validate parameter grid
+#'
+#' @export
+#' @keywords internal
+gensvm.validate.param.grid <- function(df)
+{
+ expected.colnames <- c("kernel", "coef", "degree", "gamma", "weights",
+ "kappa", "lambda", "p", "epsilon", "max.iter")
+ for (name in colnames(df)) {
+ if (!(name %in% expected.colnames)) {
+ cat(sprintf("Error: Invalid name supplied in parameter grid: %s\n",
+ name))
+ return(FALSE)
+ }
+ }
+
+ conditions <- gensvm.param.conditions()
+ for (idx in 1:nrow(df)) {
+ for (param in colnames(df)) {
+ if (!(param %in% names(conditions)))
+ next
+ func <- conditions[[param]]
+ value <- df[[param]][idx]
+ if (!func(value)) {
+ cat(sprintf("Invalid value in grid for parameter: %s\n", param))
+ return(FALSE)
+ }
+ }
+ }
+ return(TRUE)
+}
+
+gensvm.param.conditions <- function()
+{
+ conditions <- list(
+ p=function(x) { x >= 1.0 && x <= 2.0 },
+ kappa=function(x) { x > -1.0 },
+ lambda=function(x) {x > 0.0 },
+ epsilon=function(x) { x > 0.0 },
+ gamma=function(x) { x != 0.0 },
+ weights=function(x) { x %in% c("unit", "group") },
+ kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") }
+ )
+}