diff options
Diffstat (limited to 'R')
| -rw-r--r-- | R/gensvm.R | 15 | ||||
| -rw-r--r-- | R/gensvm.grid.R | 35 | ||||
| -rw-r--r-- | R/validate.R | 70 |
3 files changed, 77 insertions, 43 deletions
@@ -133,21 +133,16 @@ gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6, if (gamma == 'auto') gamma <- 1.0/n.features + if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda, + epsilon=epsilon, gamma=gamma, weights=weights, + kernel=kernel)) + return(NULL) + # Convert weights to index weight.idx <- which(c("unit", "group") == weights) - if (length(weight.idx) == 0) { - cat("Error: Incorrect weight specification. ", - "Valid options are 'unit' and 'group'") - return - } # Convert kernel to index (remember off-by-one for R vs. C) kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 - if (length(kernel.idx) == 0) { - cat("Error: Incorrect kernel specification. ", - "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") - return - } seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V) seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V) diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R index 2acdc2f..f954b79 100644 --- a/R/gensvm.grid.R +++ b/R/gensvm.grid.R @@ -162,7 +162,8 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3, } # Validate the range of the values for the gridsearch - gensvm.validate.param.grid(param.grid) + if (!gensvm.validate.param.grid(param.grid)) + return(NULL) # Sort the parameter grid for efficient warm starts param.grid <- gensvm.sort.param.grid(param.grid) @@ -379,38 +380,6 @@ gensvm.generate.cv.idx <- function(n, folds) return(cv.idx) } -gensvm.validate.param.grid <- function(df) -{ - expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight", - "kappa", "lambda", "p", "epsilon", "max.iter") - for (name in colnames(df)) { - if (!(name %in% expected.colnames)) { - stop("Invalid header name supplied in parameter grid: ", name) - } - } - - conditions <- list( - p=function(x) { x >= 1.0 && x <= 2.0 }, - kappa=function(x) { x > -1.0 }, - lambda=function(x) {x > 0.0 }, - epsilon=function(x) { x > 0.0 }, - gamma=function(x) { x != 0.0 }, - weight=function(x) { x %in% c("unit", "group") }, - kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } - ) - - for (idx in 1:nrow(df)) { - for (param in colnames(df)) { - if (!(param %in% names(conditions))) - next - func <- conditions[[param]] - value <- df[[param]][idx] - if (!func(value)) - stop("Invalid value in grid for parameter: ", param) - } - } -} - gensvm.cv.results <- function(results, param.grid, cv.idx, y.true, scoring, return.train.score=TRUE) { diff --git a/R/validate.R b/R/validate.R new file mode 100644 index 0000000..b0f3f39 --- /dev/null +++ b/R/validate.R @@ -0,0 +1,70 @@ +#' @title [internal] Validate parameters +#' +#' @export +#' @keywords internal +gensvm.validate.params <- function(p=NULL, kappa=NULL, lambda=NULL, + epsilon=NULL, gamma=NULL, weights=NULL, + kernel=NULL, ...) +{ + the.args <- as.list(match.call()) + conditions <- gensvm.param.conditions() + for (param in names(the.args)) { + if (is.null(the.args[[param]])) + next + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- eval(the.args[[param]]) + if (!func(value)) { + cat(sprintf("Error: Parameter '%s' got invalid value: %s\n", param, + toString(value))) + return(FALSE) + } + } + return(TRUE) +} + +#' @title [internal] Validate parameter grid +#' +#' @export +#' @keywords internal +gensvm.validate.param.grid <- function(df) +{ + expected.colnames <- c("kernel", "coef", "degree", "gamma", "weights", + "kappa", "lambda", "p", "epsilon", "max.iter") + for (name in colnames(df)) { + if (!(name %in% expected.colnames)) { + cat(sprintf("Error: Invalid name supplied in parameter grid: %s\n", + name)) + return(FALSE) + } + } + + conditions <- gensvm.param.conditions() + for (idx in 1:nrow(df)) { + for (param in colnames(df)) { + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- df[[param]][idx] + if (!func(value)) { + cat(sprintf("Invalid value in grid for parameter: %s\n", param)) + return(FALSE) + } + } + } + return(TRUE) +} + +gensvm.param.conditions <- function() +{ + conditions <- list( + p=function(x) { x >= 1.0 && x <= 2.0 }, + kappa=function(x) { x > -1.0 }, + lambda=function(x) {x > 0.0 }, + epsilon=function(x) { x > 0.0 }, + gamma=function(x) { x != 0.0 }, + weights=function(x) { x %in% c("unit", "group") }, + kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } + ) +} |
