diff options
| -rw-r--r-- | NAMESPACE | 2 | ||||
| -rw-r--r-- | R/gensvm.R | 15 | ||||
| -rw-r--r-- | R/gensvm.grid.R | 35 | ||||
| -rw-r--r-- | R/validate.R | 70 | ||||
| -rw-r--r-- | man/gensvm.validate.param.grid.Rd | 10 | ||||
| -rw-r--r-- | man/gensvm.validate.params.Rd | 11 |
6 files changed, 100 insertions, 43 deletions
@@ -17,4 +17,6 @@ export(gensvm.load.tiny.grid) export(gensvm.maxabs.scale) export(gensvm.refit) export(gensvm.train.test.split) +export(gensvm.validate.param.grid) +export(gensvm.validate.params) useDynLib(gensvm_wrapper, .registration = TRUE) @@ -133,21 +133,16 @@ gensvm <- function(X, y, p=1.0, lambda=1e-8, kappa=0.0, epsilon=1e-6, if (gamma == 'auto') gamma <- 1.0/n.features + if (!gensvm.validate.params(p=p, kappa=kappa, lambda=lambda, + epsilon=epsilon, gamma=gamma, weights=weights, + kernel=kernel)) + return(NULL) + # Convert weights to index weight.idx <- which(c("unit", "group") == weights) - if (length(weight.idx) == 0) { - cat("Error: Incorrect weight specification. ", - "Valid options are 'unit' and 'group'") - return - } # Convert kernel to index (remember off-by-one for R vs. C) kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1 - if (length(kernel.idx) == 0) { - cat("Error: Incorrect kernel specification. ", - "Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'") - return - } seed.rows <- if(is.null(seed.V)) -1 else nrow(seed.V) seed.cols <- if(is.null(seed.V)) -1 else ncol(seed.V) diff --git a/R/gensvm.grid.R b/R/gensvm.grid.R index 2acdc2f..f954b79 100644 --- a/R/gensvm.grid.R +++ b/R/gensvm.grid.R @@ -162,7 +162,8 @@ gensvm.grid <- function(X, y, param.grid='tiny', refit=TRUE, scoring=NULL, cv=3, } # Validate the range of the values for the gridsearch - gensvm.validate.param.grid(param.grid) + if (!gensvm.validate.param.grid(param.grid)) + return(NULL) # Sort the parameter grid for efficient warm starts param.grid <- gensvm.sort.param.grid(param.grid) @@ -379,38 +380,6 @@ gensvm.generate.cv.idx <- function(n, folds) return(cv.idx) } -gensvm.validate.param.grid <- function(df) -{ - expected.colnames <- c("kernel", "coef", "degree", "gamma", "weight", - "kappa", "lambda", "p", "epsilon", "max.iter") - for (name in colnames(df)) { - if (!(name %in% expected.colnames)) { - stop("Invalid header name supplied in parameter grid: ", name) - } - } - - conditions <- list( - p=function(x) { x >= 1.0 && x <= 2.0 }, - kappa=function(x) { x > -1.0 }, - lambda=function(x) {x > 0.0 }, - epsilon=function(x) { x > 0.0 }, - gamma=function(x) { x != 0.0 }, - weight=function(x) { x %in% c("unit", "group") }, - kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } - ) - - for (idx in 1:nrow(df)) { - for (param in colnames(df)) { - if (!(param %in% names(conditions))) - next - func <- conditions[[param]] - value <- df[[param]][idx] - if (!func(value)) - stop("Invalid value in grid for parameter: ", param) - } - } -} - gensvm.cv.results <- function(results, param.grid, cv.idx, y.true, scoring, return.train.score=TRUE) { diff --git a/R/validate.R b/R/validate.R new file mode 100644 index 0000000..b0f3f39 --- /dev/null +++ b/R/validate.R @@ -0,0 +1,70 @@ +#' @title [internal] Validate parameters +#' +#' @export +#' @keywords internal +gensvm.validate.params <- function(p=NULL, kappa=NULL, lambda=NULL, + epsilon=NULL, gamma=NULL, weights=NULL, + kernel=NULL, ...) +{ + the.args <- as.list(match.call()) + conditions <- gensvm.param.conditions() + for (param in names(the.args)) { + if (is.null(the.args[[param]])) + next + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- eval(the.args[[param]]) + if (!func(value)) { + cat(sprintf("Error: Parameter '%s' got invalid value: %s\n", param, + toString(value))) + return(FALSE) + } + } + return(TRUE) +} + +#' @title [internal] Validate parameter grid +#' +#' @export +#' @keywords internal +gensvm.validate.param.grid <- function(df) +{ + expected.colnames <- c("kernel", "coef", "degree", "gamma", "weights", + "kappa", "lambda", "p", "epsilon", "max.iter") + for (name in colnames(df)) { + if (!(name %in% expected.colnames)) { + cat(sprintf("Error: Invalid name supplied in parameter grid: %s\n", + name)) + return(FALSE) + } + } + + conditions <- gensvm.param.conditions() + for (idx in 1:nrow(df)) { + for (param in colnames(df)) { + if (!(param %in% names(conditions))) + next + func <- conditions[[param]] + value <- df[[param]][idx] + if (!func(value)) { + cat(sprintf("Invalid value in grid for parameter: %s\n", param)) + return(FALSE) + } + } + } + return(TRUE) +} + +gensvm.param.conditions <- function() +{ + conditions <- list( + p=function(x) { x >= 1.0 && x <= 2.0 }, + kappa=function(x) { x > -1.0 }, + lambda=function(x) {x > 0.0 }, + epsilon=function(x) { x > 0.0 }, + gamma=function(x) { x != 0.0 }, + weights=function(x) { x %in% c("unit", "group") }, + kernel=function(x) { x %in% c("linear", "poly", "rbf", "sigmoid") } + ) +} diff --git a/man/gensvm.validate.param.grid.Rd b/man/gensvm.validate.param.grid.Rd new file mode 100644 index 0000000..5528f16 --- /dev/null +++ b/man/gensvm.validate.param.grid.Rd @@ -0,0 +1,10 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/validate.R +\name{gensvm.validate.param.grid} +\alias{gensvm.validate.param.grid} +\title{[internal] Validate parameter grid} +\usage{ +gensvm.validate.param.grid(df) +} +\keyword{internal} + diff --git a/man/gensvm.validate.params.Rd b/man/gensvm.validate.params.Rd new file mode 100644 index 0000000..1d249a9 --- /dev/null +++ b/man/gensvm.validate.params.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/validate.R +\name{gensvm.validate.params} +\alias{gensvm.validate.params} +\title{[internal] Validate parameters} +\usage{ +gensvm.validate.params(p = NULL, kappa = NULL, lambda = NULL, + epsilon = NULL, gamma = NULL, weights = NULL, kernel = NULL, ...) +} +\keyword{internal} + |
