aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2018-02-23 17:10:16 +0000
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2018-02-23 17:10:16 +0000
commitbdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468 (patch)
treedcf17ae3e5fef6cee367f0da985bb4894495aee9 /R
parentupdate gensvm C library (diff)
downloadrgensvm-bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468.tar.gz
rgensvm-bdeeb59f8f64a9c3a2e083f7e5c33a4c30a2c468.zip
Implement fitting and prediction
Diffstat (limited to 'R')
-rw-r--r--R/gensvm.R46
-rw-r--r--R/predict.gensvm.R25
-rw-r--r--R/print.gensvm.R28
-rw-r--r--R/util.labelencoder.R1
4 files changed, 73 insertions, 27 deletions
diff --git a/R/gensvm.R b/R/gensvm.R
index 1923f06..4a4ab6b 100644
--- a/R/gensvm.R
+++ b/R/gensvm.R
@@ -56,9 +56,9 @@
#' \code{\link{plot}}, and \code{\link{gensvm.grid}}.
#'
#' @export
+#' @useDynLib gensvm_wrapper, .registration = TRUE
#'
#' @examples
-#' X <-
#'
gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
weights='unit', kernel='linear', gamma='auto', coef=0.0,
@@ -67,10 +67,10 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
{
call <- match.call()
-
- # TODO: generate the random.seed value in R if it is NULL. Then you can
- # return it and people can still reproduce even if they forgot to set it
- # explicitly.
+ # Generate the random.seed value in R if it is NULL. This way users can
+ # reproduce the run because it is returned in the output object.
+ if (is.null(random.seed))
+ random.seed <- runif(1) * (2**31 - 1)
# TODO: Store a labelencoder in the object, preferably as a partially
# hidden item. This can then be used with prediction.
@@ -79,9 +79,13 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
n.features <- ncol(X)
n.classes <- length(unique(y))
-
# Convert labels to integers
- y.clean <- label.encode(y)
+ classes <- sort(unique(y))
+ y.clean <- match(y, classes)
+
+ # Convert gamma if it is 'auto'
+ if (gamma == 'auto')
+ gamma <- 1.0/n.features
# Convert weights to index
weight.idx <- which(c("unit", "group") == weights)
@@ -90,39 +94,43 @@ gensvm <- function(X, y, p=1.0, lambda=1e-5, kappa=0.0, epsilon=1e-6,
"Valid options are 'unit' and 'group'")
}
- # Convert kernel to index
- kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel)
+ # Convert kernel to index (remember off-by-one for R vs. C)
+ kernel.idx <- which(c("linear", "poly", "rbf", "sigmoid") == kernel) - 1
if (length(kernel.idx) == 0) {
stop("Incorrect kernel specification. ",
"Valid options are 'linear', 'poly', 'rbf', and 'sigmoid'")
}
-
out <- .Call("R_gensvm_train",
- as.matrix(t(X)),
+ as.matrix(X),
as.integer(y.clean),
p,
lambda,
kappa,
epsilon,
weight.idx,
- kernel.idx,
+ as.integer(kernel.idx),
gamma,
coef,
degree,
kernel.eigen.cutoff,
- verbose,
- max.iter,
- random.seed,
- seed.V)
-
+ as.integer(verbose),
+ as.integer(max.iter),
+ as.integer(random.seed),
+ seed.V,
+ as.integer(n.objects),
+ as.integer(n.features),
+ as.integer(n.classes))
- object <- list(call = call, lambda = lambda, kappa = kappa,
+ object <- list(call = call, p = p, lambda = lambda, kappa = kappa,
epsilon = epsilon, weights = weights, kernel = kernel,
gamma = gamma, coef = coef, degree = degree,
kernel.eigen.cutoff = kernel.eigen.cutoff,
random.seed = random.seed, max.iter = max.iter,
- V = out$V, n.iter = out$n.iter, n.support = out$n.support)
+ n.objects = n.objects, n.features = n.features,
+ n.classes = n.classes, classes = classes, V = out$V,
+ n.iter = out$n.iter, n.support = out$n.support)
class(object) <- "gensvm"
+
return(object)
}
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 6cc8851..5c8f2e7 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -3,7 +3,7 @@
#' @description This function predicts the class labels of new data using a
#' fitted GenSVM model.
#'
-#' @param object Fitted \code{gensvm} object
+#' @param fit Fitted \code{gensvm} object
#' @param newx Matrix of new values for \code{x} for which predictions need to
#' be made.
#' @param \dots further arguments are ignored
@@ -27,10 +27,27 @@
#'
#'
#'
-predict.gensvm <- function(object, newx, ...)
+predict.gensvm <- function(fit, newx, ...)
{
- # TODO: C library fitting prediction here (or not? with the column-major
- # order it may be faster to do it directly in R)
+ ## Implementation note:
+ ## - It might seem that it would be faster to do the prediction directly in
+ ## R here, since we then don't need to switch to C, construct model and
+ ## data structures, copy the data, etc. before doing the prediction.
+ ## However, an actual comparison of the two implementations shows that
+ ## calling the C implementation is *much* faster than doing it in R.
+
+ # Sanity check
+ if (ncol(newx) != fit$n.features)
+ stop("Number of features of fitted model and supplied data disagree.")
+
+ y.pred.c <- .Call("R_gensvm_predict",
+ as.matrix(newx),
+ as.matrix(fit$V),
+ as.integer(nrow(newx)),
+ as.integer(ncol(newx)),
+ as.integer(fit$n.classes)
+ )
+ yhat <- fit$classes[y.pred.c]
return(yhat)
}
diff --git a/R/print.gensvm.R b/R/print.gensvm.R
index 8d17b0c..06a3649 100644
--- a/R/print.gensvm.R
+++ b/R/print.gensvm.R
@@ -26,9 +26,31 @@ print.gensvm <- function(object, ...)
{
cat("\nCall:\n")
dput(object$call)
+ cat("\nData:\n")
+ cat("\tn.objects:", object$n.objects, "\n")
+ cat("\tn.features:", object$n.features, "\n")
+ cat("\tn.classes:", object$n.classes, "\n")
+ cat("\tclasses:", object$classes, "\n")
+ cat("Parameters:\n")
+ cat("\tp:", object$p, "\n")
+ cat("\tlambda:", object$lambda, "\n")
+ cat("\tkappa:", object$kappa, "\n")
+ cat("\tepsilon:", object$epsilon, "\n")
+ cat("\tweights:", object$weights, "\n")
+ cat("\tmax.iter:", object$max.iter, "\n")
+ cat("\trandom.seed:", object$random.seed, "\n")
+ cat("\tkernel:", object$kernel, "\n")
+ if (object$kernel %in% c("poly", "rbf", "sigmoid")) {
+ cat("\tkernel.eigen.cutoff:", object$kernel.eigen.cutoff, "\n")
+ cat("\tgamma:", object$gamma, "\n")
+ }
+ if (object$kernel %in% c("poly", "sigmoid"))
+ cat("\tcoef:", object$coef, "\n")
+ if (object$kernel == 'poly')
+ cat("\tdegree:", object$degree, "\n")
+ cat("Results:\n")
+ cat("\tn.iter:", object$n.iter, "\n")
+ cat("\tn.support:", object$n.support, "\n")
- # TODO: fill this out
- #
- #
invisible(object)
}
diff --git a/R/util.labelencoder.R b/R/util.labelencoder.R
deleted file mode 100644
index 8b13789..0000000
--- a/R/util.labelencoder.R
+++ /dev/null
@@ -1 +0,0 @@
-