diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-28 13:41:14 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-03-28 13:41:14 +0100 |
| commit | 75d6ed3e5d919b4e5b7bd1e81283131c92530d26 (patch) | |
| tree | 907f9de5769412c25e61d0c67a50089e6d8e1d6e /R | |
| parent | Make sure we use "weights" everywhere instead of "weight" (diff) | |
| download | rgensvm-75d6ed3e5d919b4e5b7bd1e81283131c92530d26.tar.gz rgensvm-75d6ed3e5d919b4e5b7bd1e81283131c92530d26.zip | |
Validate params in gensvm() function
Parameter validation was only done for some parameters in the
gensvm() function and for the parameter grid in gensvm.grid()
With this commit the parameters will be tested properly for
both functions.
Diffstat (limited to 'R')
| -rw-r--r-- | R/gensvm.R | 15 | ||||
| -rw-r--r-- | R/gensvm.grid.R | 35 | ||||
| -rw-r--r-- | R/validate.R | 70 |
3 files changed, 77 insertions, 43 deletions
@@ -133,21 +133,16 @@ gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6, if (gamma == 'auto') gamma <- 1.0/n.features + if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda, + epsilon=epsilon, gamma=gamma, weights=weights, + kernel=kernel)) + return(NULL) + # Convert weights to index weight.idx <- which(c("unit", "group") == weights) - if (length(weight.idx) == 0) { - cat("Error: Incorrect weight specification. ", - "Valid options are 'unit' and 'group'") - return - } # Convert kernel to index (remember off-by-one for R vs. C) kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 - if (length(kernel.idx) == 0) { - cat("Error: Incorrect kernel specification. ", - "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") - return - } seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V) seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V) diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R index 2acdc2f..f954b79 100644 --- a/R/gensvm.grid.R +++ b/R/gensvm.grid.R @@ -162,7 +162,8 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3, } # Validate the range of the values for the gridsearch - gensvm.validate.param.grid(param.grid) + if (!gensvm.validate.param.grid(param.grid)) + return(NULL) # Sort the parameter grid for efficient warm starts param.grid <- gensvm.sort.param.grid(param.grid) @@ -379,38 +380,6 @@ gensvm.generate.cv.idx <- function(n, folds) return(cv.idx) } -gensvm.validate.param.grid <- function(df) -{ - expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight", - "kappa", "lambda", "p", "epsilon", "max.iter") - for (name in colnames(df)) { - if (!(name %in% expected.colnames)) { - stop("Invalid header name supplied in parameter grid: ", name) - } - } - - conditions <- list( - p=function(x) { x >= 1.0 && x <= 2.0 }, - kappa=function(x) { x > -1.0 }, - lambda=function(x) {x > 0.0 }, - epsilon=function(x) { x > 0.0 }, - gamma=function(x) { x != 0.0 }, - weight=function(x) { x %in% c("unit", "group") }, - kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } - ) - - for (idx in 1:nrow(df)) { - for (param in colnames(df)) { - if (!(param %in% names(conditions))) - next - func <- conditions[[param]] - value <- df[[param]][idx] - if (!func(value)) - stop("Invalid value in grid for parameter: ", param) - } - } -} - gensvm.cv.results <- function(results, param.grid, cv.idx, y.true, scoring, return.train.score=TRUE) { diff --git a/R/validate.R b/R/validate.R new file mode 100644 index 0000000..b0f3f39 --- /dev/null +++ b/R/validate.R @@ -0,0 +1,70 @@ +#' @title [internal] Validate parameters +#' +#' @export +#' @keywords internal +gensvm.validate.params <- function(p=NULL, kappa=NULL, lambda=NULL, + epsilon=NULL, gamma=NULL, weights=NULL, + kernel=NULL, ...) +{ + the.args <- as.list(match.call()) + conditions <- gensvm.param.conditions() + for (param in names(the.args)) { + if (is.null(the.args[[param]])) + next + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- eval(the.args[[param]]) + if (!func(value)) { + cat(sprintf("Error: Parameter '%s' got invalid value: %s\n", param, + toString(value))) + return(FALSE) + } + } + return(TRUE) +} + +#' @title [internal] Validate parameter grid +#' +#' @export +#' @keywords internal +gensvm.validate.param.grid <- function(df) +{ + expected.colnames <- c("kernel", "coef", "degree", "gamma", "weights", + "kappa", "lambda", "p", "epsilon", "max.iter") + for (name in colnames(df)) { + if (!(name %in% expected.colnames)) { + cat(sprintf("Error: Invalid name supplied in parameter grid: %s\n", + name)) + return(FALSE) + } + } + + conditions <- gensvm.param.conditions() + for (idx in 1:nrow(df)) { + for (param in colnames(df)) { + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- df[[param]][idx] + if (!func(value)) { + cat(sprintf("Invalid value in grid for parameter: %s\n", param)) + return(FALSE) + } + } + } + return(TRUE) +} + +gensvm.param.conditions <- function() +{ + conditions <- list( + p=function(x) { x >= 1.0 && x <= 2.0 }, + kappa=function(x) { x > -1.0 }, + lambda=function(x) {x > 0.0 }, + epsilon=function(x) { x > 0.0 }, + gamma=function(x) { x != 0.0 }, + weights=function(x) { x %in% c("unit", "group") }, + kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } + ) +} |
