author     Gertjan van den Burg <gertjanvandenburg@gmail.com>   2018-04-04 15:06:33 -0400
committer  Gertjan van den Burg <gertjanvandenburg@gmail.com>   2018-04-04 15:06:33 -0400
commit     de17a6d755e9369a91abdb06562ee93d7b323bbd (patch)
tree       cbede010da1046ee6627143aa0e41d8ec0fbb09e
parent     Add importFrom statements (diff)
download   rgensvm-de17a6d755e9369a91abdb06562ee93d7b323bbd.tar.gz
           rgensvm-de17a6d755e9369a91abdb06562ee93d7b323bbd.zip
Adhere to generic function signatures
-rw-r--r--  R/coef.gensvm.R          |  8
-rw-r--r--  R/fitted.gensvm.R        | 10
-rw-r--r--  R/fitted.gensvm.grid.R   |  9
-rw-r--r--  R/plot.gensvm.R          | 26
-rw-r--r--  R/plot.gensvm.grid.R     | 10
-rw-r--r--  R/predict.gensvm.R       | 38
-rw-r--r--  R/predict.gensvm.grid.R  | 12
-rw-r--r--  R/print.gensvm.R         | 58
-rw-r--r--  R/print.gensvm.grid.R    | 36
9 files changed, 106 insertions, 101 deletions
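
Note: R's S3 dispatch requires a method to use the same name for the first formal argument as its generic: coef(), fitted(), and predict() take `object`, while plot() and print() take `x`; mismatched names trigger a warning from R CMD check. A minimal sketch of the pattern this commit adopts, using a hypothetical class `myclass` that is not part of the package:

    # The generics fix the name of the first argument, e.g.
    #   coef  <- function(object, ...) UseMethod("coef")
    #   print <- function(x, ...) UseMethod("print")
    # so a method must reuse that name:
    print.myclass <- function(x, ...)
    {
        cat("myclass object with", length(x$values), "values\n")
        invisible(x)
    }
    obj <- structure(list(values = 1:3), class = "myclass")
    print(obj)  # dispatches to print.myclass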
diff --git a/R/coef.gensvm.R b/R/coef.gensvm.R
index 8f60fd4..f0369ee 100644
--- a/R/coef.gensvm.R
+++ b/R/coef.gensvm.R
@@ -2,7 +2,7 @@
#'
#' @description Returns the model coefficients of the GenSVM object
#'
-#' @param fit a \code{gensvm} object
+#' @param object a \code{gensvm} object
#' @param \dots further arguments are ignored
#'
#' @return The coefficients of the GenSVM model. This is a matrix of size
@@ -38,10 +38,10 @@
#' fit <- gensvm(x, y)
#' V <- coef(fit)
#'
-coef.gensvm <- function(fit, ...)
+coef.gensvm <- function(object, ...)
{
- V <- fit$V
- x <- eval.parent(fit$call$x)
+ V <- object$V
+ x <- eval.parent(object$call$x)
name <- c("translation", colnames(x))
rownames(V) <- name
return(V)
diff --git a/R/fitted.gensvm.R b/R/fitted.gensvm.R
index 72b5db6..be97223 100644
--- a/R/fitted.gensvm.R
+++ b/R/fitted.gensvm.R
@@ -3,7 +3,7 @@
#' @description This function shows the fitted class labels of training data
#' using a fitted GenSVM model.
#'
-#' @param fit Fitted \code{gensvm} object
+#' @param object Fitted \code{gensvm} object
#' @param \dots further arguments are passed to predict
#'
#' @return a vector of class labels, with the same type as the original class
@@ -22,8 +22,9 @@
#' \code{\link{plot.gensvm}}, \code{\link{predict.gensvm.grid}},
#' \code{\link{gensvm}}, \code{\link{gensvm-package}}
#'
+#' @method fitted gensvm
+#'
#' @export
-#' @aliases fitted
#'
#' @examples
#' x <- iris[, -5]
@@ -36,7 +37,8 @@
#' # compute the accuracy with gensvm.accuracy
#' gensvm.accuracy(y, yhat)
#'
-fitted.gensvm <- function(fit, ...)
+fitted.gensvm <- function(object, ...)
{
- return(predict(fit, ...))
+ x <- eval.parent(object$call$x)
+ return(predict(object, x, ...))
}
diff --git a/R/fitted.gensvm.grid.R b/R/fitted.gensvm.grid.R
index c6887c0..f24ec64 100644
--- a/R/fitted.gensvm.grid.R
+++ b/R/fitted.gensvm.grid.R
@@ -3,7 +3,7 @@
#' @description Wrapper to get the fitted class labels from the best estimator
#' of the fitted GenSVMGrid model. Only works if refit was enabled.
#'
-#' @param grid A \code{gensvm.grid} object
+#' @param object A \code{gensvm.grid} object
#' @param \dots further arguments are passed to fitted
#'
#' @return a vector of class labels, with the same type as the original class
@@ -22,8 +22,9 @@
#' \code{\link{plot.gensvm}}, \code{\link{predict.gensvm.grid}},
#' \code{\link{gensvm}}, \code{\link{gensvm-package}}
#'
+#' @method fitted gensvm.grid
+#'
#' @export
-#' @aliases fitted
#'
#' @examples
#' x <- iris[, -5]
@@ -36,7 +37,7 @@
#' # compute the accuracy with gensvm.accuracy
#' gensvm.accuracy(y, yhat)
#'
-fitted.gensvm.grid <- function(grid, ...)
+fitted.gensvm.grid <- function(object, ...)
{
- return(predict(grid, ...))
+ return(predict(object, ...))
}
diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R
index d490c1e..129d90a 100644
--- a/R/plot.gensvm.R
+++ b/R/plot.gensvm.R
@@ -5,10 +5,11 @@
#' 3 classes. For more than 3 classes, the simplex space is too high
#' dimensional to easily visualize.
#'
-#' @param fit A fitted \code{gensvm} object
-#' @param x the dataset to plot (if NULL the training data is used)
-#' @param y the labels to color points with (if NULL the predicted labels are
-#' used)
+#' @param x A fitted \code{gensvm} object
+#' @param labels the labels to color points with. If this is omitted the
+#' predicted labels are used.
+#' @param newdata the dataset to plot. If this is NULL the training data is
+#' used.
#' @param with.margins plot the margins
#' @param with.shading show shaded areas for the class regions
#' @param with.legend show the legend for the class labels
@@ -62,19 +63,20 @@
#' plot(fit, x.test=x.mis)
#' plot(fit, y.mis.true, x.test=x.mis)
#'
-plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
+plot.gensvm <- function(x, labels, newdata=NULL, with.margins=TRUE,
with.shading=TRUE, with.legend=TRUE, center.plot=TRUE,
xlim=NULL, ylim=NULL, ...)
{
+ fit <- x
if (!(fit$n.classes %in% c(2,3))) {
cat("Error: Can only plot with 2 or 3 classes\n")
return(invisible(NULL))
}
- x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test
+ newdata <- if(is.null(newdata)) eval.parent(fit$call$x) else newdata
# Sanity check
- if (ncol(x.test) != fit$n.features) {
+ if (ncol(newdata) != fit$n.features) {
cat("Error: Number of features of fitted model and testing data
disagree.\n")
return(invisible(NULL))
@@ -83,20 +85,20 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
x.train <- eval.parent(fit$call$x)
if (fit$kernel == 'linear') {
V <- coef(fit)
- Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test))
+ Z <- cbind(matrix(1, dim(newdata)[1], 1), as.matrix(newdata))
S <- Z %*% V
- y.pred.orig <- predict(fit, x.test)
+ y.pred.orig <- predict(fit, newdata)
} else {
kernels <- c("linear", "poly", "rbf", "sigmoid")
kernel.idx <- which(kernels == fit$kernel) - 1
plotdata <- .Call("R_gensvm_plotdata_kernels",
- as.matrix(x.test),
+ as.matrix(newdata),
as.matrix(x.train),
as.matrix(fit$V),
as.integer(nrow(fit$V)),
as.integer(ncol(fit$V)),
as.integer(nrow(x.train)),
- as.integer(nrow(x.test)),
+ as.integer(nrow(newdata)),
as.integer(fit$n.features),
as.integer(fit$n.classes),
as.integer(kernel.idx),
@@ -116,7 +118,7 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE,
y.pred <- y.pred.orig
}
- labels <- if(missing(y)) y.pred else match(y, classes)
+ labels <- if(missing(labels)) y.pred else match(labels, classes)
colors <- gensvm.plot.colors(fit$n.classes)
markers <- gensvm.plot.markers(fit$n.classes)
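
Note: with the renamed arguments, the plot method is invoked through the generic as in the sketch below (usage only; it reuses the iris data from the roxygen examples and assumes the gensvm package is loaded):

    x <- iris[, -5]
    y <- iris[, 5]
    fit <- gensvm(x, y)
    plot(fit)                  # training data, colored by predicted labels
    plot(fit, y, newdata = x)  # same data, colored by the given labels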
diff --git a/R/plot.gensvm.grid.R b/R/plot.gensvm.grid.R
index abb0601..6a34024 100644
--- a/R/plot.gensvm.grid.R
+++ b/R/plot.gensvm.grid.R
@@ -4,7 +4,7 @@
#' model in the provided GenSVMGrid object. See the documentation for
#' \code{\link{plot.gensvm}} for more information.
#'
-#' @param grid A \code{gensvm.grid} object trained with refit=TRUE
+#' @param x A \code{gensvm.grid} object trained with refit=TRUE
#' @param ... further arguments are passed to the plot function
#'
#' @return returns the object passed as input
@@ -31,12 +31,12 @@
#' grid <- gensvm.grid(x, y)
#' plot(grid, x)
#'
-plot.gensvm.grid <- function(grid, ...)
+plot.gensvm.grid <- function(x, ...)
{
- if (is.null(grid$best.estimator)) {
+ if (is.null(x$best.estimator)) {
cat("Error: Can't plot, the best.estimator element is NULL\n")
- return
+ return(invisible(NULL))
}
- fit <- grid$best.estimator
+ fit <- x$best.estimator
return(plot(fit, ...))
}
diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R
index 0b5fa91..43a0d52 100644
--- a/R/predict.gensvm.R
+++ b/R/predict.gensvm.R
@@ -3,7 +3,7 @@
#' @description This function predicts the class labels of new data using a
#' fitted GenSVM model.
#'
-#' @param fit Fitted \code{gensvm} object
+#' @param object Fitted \code{gensvm} object
#' @param newdata Matrix of new data for which predictions need to be made.
#' @param add.rownames add the rownames from the training data to the
#' predictions
@@ -45,7 +45,7 @@
#' # compute the accuracy with gensvm.accuracy
#' gensvm.accuracy(y.test, y.test.pred)
#'
-predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...)
+predict.gensvm <- function(object, newdata, add.rownames=FALSE, ...)
{
## Implementation note:
## - It might seem that it would be faster to do the prediction directly in
@@ -55,48 +55,48 @@ predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...)
## the C implementation is *much* faster than doing it in R.
if (missing(newdata)) {
- newdata <- eval.parent(fit$call$x)
+ newdata <- eval.parent(object$call$x)
}
x.test <- as.matrix(newdata)
# Sanity check
- if (ncol(x.test) != fit$n.features) {
+ if (ncol(x.test) != object$n.features) {
cat("Error: Number of features of fitted model and testing",
"data disagree.\n")
return(invisible(NULL))
}
- x.train <- eval.parent(fit$call$x)
- if (fit$kernel == 'linear') {
+ x.train <- eval.parent(object$call$x)
+ if (object$kernel == 'linear') {
y.pred.c <- .Call("R_gensvm_predict",
as.matrix(x.test),
- as.matrix(fit$V),
+ as.matrix(object$V),
as.integer(nrow(x.test)),
as.integer(ncol(x.test)),
- as.integer(fit$n.classes)
+ as.integer(object$n.classes)
)
} else {
kernels <- c("linear", "poly", "rbf", "sigmoid")
- kernel.idx <- which(kernels == fit$kernel) - 1
+ kernel.idx <- which(kernels == object$kernel) - 1
y.pred.c <- .Call("R_gensvm_predict_kernels",
as.matrix(x.test),
as.matrix(x.train),
- as.matrix(fit$V),
- as.integer(nrow(fit$V)),
- as.integer(ncol(fit$V)),
+ as.matrix(object$V),
+ as.integer(nrow(object$V)),
+ as.integer(ncol(object$V)),
as.integer(nrow(x.train)),
as.integer(nrow(x.test)),
- as.integer(fit$n.features),
- as.integer(fit$n.classes),
+ as.integer(object$n.features),
+ as.integer(object$n.classes),
as.integer(kernel.idx),
- fit$gamma,
- fit$coef,
- fit$degree,
- fit$kernel.eigen.cutoff
+ object$gamma,
+ object$coef,
+ object$degree,
+ object$kernel.eigen.cutoff
)
}
- yhat <- fit$classes[y.pred.c]
+ yhat <- object$classes[y.pred.c]
if (add.rownames) {
yhat <- matrix(yhat, length(yhat), 1)
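
Note: a usage sketch of the renamed prediction interface, following the train/test split of the roxygen example (x.train, y.train, x.test, y.test are illustrative names; the gensvm package is assumed to be loaded):

    fit <- gensvm(x.train, y.train)
    y.test.pred <- predict(fit, x.test)   # dispatches to predict.gensvm
    gensvm.accuracy(y.test, y.test.pred)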
diff --git a/R/predict.gensvm.grid.R b/R/predict.gensvm.grid.R
index df48d76..acc838f 100644
--- a/R/predict.gensvm.grid.R
+++ b/R/predict.gensvm.grid.R
@@ -7,9 +7,9 @@
#' this model is only available if \code{refit=TRUE} was specified in the
#' \code{\link{gensvm.grid}} call (the default).
#'
-#' @param grid A \code{gensvm.grid} object trained with \code{refit=TRUE}
-#' @param newx Matrix of new values for \code{x} for which predictions need to
-#' be computed.
+#' @param object A \code{gensvm.grid} object trained with \code{refit=TRUE}
+#' @param newdata Matrix of new values for \code{x} for which predictions need
+#' to be computed.
#' @param \dots further arguments are passed to predict.gensvm()
#'
#' @return a vector of class labels, with the same type as the original class
@@ -44,12 +44,12 @@
#' # predict training sample
#' y.hat <- predict(grid, x)
#'
-predict.gensvm.grid <- function(grid, newx, ...)
+predict.gensvm.grid <- function(object, newdata, ...)
{
- if (is.null(grid$best.estimator)) {
+ if (is.null(object$best.estimator)) {
cat("Error: Can't predict, the best.estimator element is NULL\n")
return
}
- return(predict(grid$best.estimator, newx, ...))
+ return(predict(object$best.estimator, newdata, ...))
}
diff --git a/R/print.gensvm.R b/R/print.gensvm.R
index 724806b..54397da 100644
--- a/R/print.gensvm.R
+++ b/R/print.gensvm.R
@@ -2,7 +2,7 @@
#'
#' @description Prints a short description of the fitted GenSVM model
#'
-#' @param fit A \code{gensvm} object to print
+#' @param x A \code{gensvm} object to print
#' @param \dots further arguments are ignored
#'
#' @return returns the object passed as input. This can be useful for chaining
@@ -36,41 +36,41 @@
#' fit <- gensvm(x, y)
#' predict(print(fit), x)
#'
-print.gensvm <- function(fit, ...)
+print.gensvm <- function(x, ...)
{
cat("Data:\n")
- cat("\tn.objects:", fit$n.objects, "\n")
- cat("\tn.features:", fit$n.features, "\n")
- cat("\tn.classes:", fit$n.classes, "\n")
- if (is.factor(fit$classes))
- cat("\tclasses:", levels(fit$classes), "\n")
+ cat("\tn.objects:", x$n.objects, "\n")
+ cat("\tn.features:", x$n.features, "\n")
+ cat("\tn.classes:", x$n.classes, "\n")
+ if (is.factor(x$classes))
+ cat("\tclasses:", levels(x$classes), "\n")
else
- cat("\tclasses:", fit$classes, "\n")
+ cat("\tclasses:", x$classes, "\n")
cat("Parameters:\n")
- cat("\tp:", fit$p, "\n")
- cat("\tlambda:", fit$lambda, "\n")
- cat("\tkappa:", fit$kappa, "\n")
- cat("\tepsilon:", fit$epsilon, "\n")
- cat("\tweights:", fit$weights, "\n")
- cat("\tmax.iter:", fit$max.iter, "\n")
- cat("\trandom.seed:", fit$random.seed, "\n")
- if (is.factor(fit$kernel)) {
- cat("\tkernel:", levels(fit$kernel)[as.numeric(fit$kernel)], "\n")
+ cat("\tp:", x$p, "\n")
+ cat("\tlambda:", x$lambda, "\n")
+ cat("\tkappa:", x$kappa, "\n")
+ cat("\tepsilon:", x$epsilon, "\n")
+ cat("\tweights:", x$weights, "\n")
+ cat("\tmax.iter:", x$max.iter, "\n")
+ cat("\trandom.seed:", x$random.seed, "\n")
+ if (is.factor(x$kernel)) {
+ cat("\tkernel:", levels(x$kernel)[as.numeric(x$kernel)], "\n")
} else {
- cat("\tkernel:", fit$kernel, "\n")
+ cat("\tkernel:", x$kernel, "\n")
}
- if (fit$kernel %in% c("poly", "rbf", "sigmoid")) {
- cat("\tkernel.eigen.cutoff:", fit$kernel.eigen.cutoff, "\n")
- cat("\tgamma:", fit$gamma, "\n")
+ if (x$kernel %in% c("poly", "rbf", "sigmoid")) {
+ cat("\tkernel.eigen.cutoff:", x$kernel.eigen.cutoff, "\n")
+ cat("\tgamma:", x$gamma, "\n")
}
- if (fit$kernel %in% c("poly", "sigmoid"))
- cat("\tcoef:", fit$coef, "\n")
- if (fit$kernel == 'poly')
- cat("\tdegree:", fit$degree, "\n")
+ if (x$kernel %in% c("poly", "sigmoid"))
+ cat("\tcoef:", x$coef, "\n")
+ if (x$kernel == 'poly')
+ cat("\tdegree:", x$degree, "\n")
cat("Results:\n")
- cat("\ttime:", fit$training.time, "\n")
- cat("\tn.iter:", fit$n.iter, "\n")
- cat("\tn.support:", fit$n.support, "\n")
+ cat("\ttime:", x$training.time, "\n")
+ cat("\tn.iter:", x$n.iter, "\n")
+ cat("\tn.support:", x$n.support, "\n")
- invisible(fit)
+ invisible(x)
}
diff --git a/R/print.gensvm.grid.R b/R/print.gensvm.grid.R
index 5e4c5da..3e1bb69 100644
--- a/R/print.gensvm.grid.R
+++ b/R/print.gensvm.grid.R
@@ -2,7 +2,7 @@
#'
#' @description Prints the summary of the fitted GenSVMGrid model
#'
-#' @param grid a \code{gensvm.grid} object to print
+#' @param x a \code{gensvm.grid} object to print
#' @param \dots further arguments are ignored
#'
#' @return returns the object passed as input
@@ -32,38 +32,38 @@
#' grid <- gensvm.grid(x, y)
#' print(grid)
#'
-print.gensvm.grid <- function(grid, ...)
+print.gensvm.grid <- function(x, ...)
{
cat("Data:\n")
- cat("\tn.objects:", grid$n.objects, "\n")
- cat("\tn.features:", grid$n.features, "\n")
- cat("\tn.classes:", grid$n.classes, "\n")
- if (is.factor(grid$classes))
- cat("\tclasses:", levels(grid$classes), "\n")
+ cat("\tn.objects:", x$n.objects, "\n")
+ cat("\tn.features:", x$n.features, "\n")
+ cat("\tn.classes:", x$n.classes, "\n")
+ if (is.factor(x$classes))
+ cat("\tclasses:", levels(x$classes), "\n")
else
- cat("\tclasses:", grid$classes, "\n")
+ cat("\tclasses:", x$classes, "\n")
cat("Config:\n")
- cat("\tNumber of cv splits:", grid$n.splits, "\n")
- not.run <- sum(is.na(grid$cv.results$rank.test.score))
+ cat("\tNumber of cv splits:", x$n.splits, "\n")
+ not.run <- sum(is.na(x$cv.results$rank.test.score))
if (not.run > 0) {
- cat("\tParameter grid size:", dim(grid$param.grid)[1])
+ cat("\tParameter grid size:", dim(x$param.grid)[1])
cat(" (", not.run, " incomplete)", sep="")
cat("\n")
} else {
- cat("\tParameter grid size:", dim(grid$param.grid)[1], "\n")
+ cat("\tParameter grid size:", dim(x$param.grid)[1], "\n")
}
cat("Results:\n")
- cat("\tTotal grid search time:", grid$total.time, "\n")
- if (!is.na(grid$best.index)) {
- best <- grid$cv.results[grid$best.index, ]
+ cat("\tTotal grid search time:", x$total.time, "\n")
+ if (!is.na(x$best.index)) {
+ best <- x$cv.results[x$best.index, ]
cat("\tBest mean test score:", best$mean.test.score, "\n")
cat("\tBest mean fit time:", best$mean.fit.time, "\n")
- for (name in colnames(grid$best.params)) {
- val <- grid$best.params[[name]]
+ for (name in colnames(x$best.params)) {
+ val <- x$best.params[[name]]
val <- if(is.factor(val)) levels(val)[val] else val
cat("\tBest parameter", name, "=", val, "\n")
}
}
- invisible(grid)
+ invisible(x)
}