diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-04-04 15:06:33 -0400 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2018-04-04 15:06:33 -0400 |
| commit | de17a6d755e9369a91abdb06562ee93d7b323bbd (patch) | |
| tree | cbede010da1046ee6627143aa0e41d8ec0fbb09e | |
| parent | Add importFrom statements (diff) | |
| download | rgensvm-de17a6d755e9369a91abdb06562ee93d7b323bbd.tar.gz rgensvm-de17a6d755e9369a91abdb06562ee93d7b323bbd.zip | |
Adhere to generic function signatures
| -rw-r--r-- | R/coef.gensvm.R | 8 | ||||
| -rw-r--r-- | R/fitted.gensvm.R | 10 | ||||
| -rw-r--r-- | R/fitted.gensvm.grid.R | 9 | ||||
| -rw-r--r-- | R/plot.gensvm.R | 26 | ||||
| -rw-r--r-- | R/plot.gensvm.grid.R | 10 | ||||
| -rw-r--r-- | R/predict.gensvm.R | 38 | ||||
| -rw-r--r-- | R/predict.gensvm.grid.R | 12 | ||||
| -rw-r--r-- | R/print.gensvm.R | 58 | ||||
| -rw-r--r-- | R/print.gensvm.grid.R | 36 |
9 files changed, 106 insertions, 101 deletions
diff --git a/R/coef.gensvm.R b/R/coef.gensvm.R index 8f60fd4..f0369ee 100644 --- a/R/coef.gensvm.R +++ b/R/coef.gensvm.R @@ -2,7 +2,7 @@ #' #' @description Returns the model coefficients of the GenSVM object #' -#' @param fit a \code{gensvm} object +#' @param object a \code{gensvm} object #' @param \dots further arguments are ignored #' #' @return The coefficients of the GenSVM model. This is a matrix of size @@ -38,10 +38,10 @@ #' fit <- gensvm(x, y) #' V <- coef(fit) #' -coef.gensvm <- function(fit, ...) +coef.gensvm <- function(object, ...) { - V <- fit$V - x <- eval.parent(fit$call$x) + V <- object$V + x <- eval.parent(object$call$x) name <- c("translation", colnames(x)) rownames(V) <- name return(V) diff --git a/R/fitted.gensvm.R b/R/fitted.gensvm.R index 72b5db6..be97223 100644 --- a/R/fitted.gensvm.R +++ b/R/fitted.gensvm.R @@ -3,7 +3,7 @@ #' @description This function shows the fitted class labels of training data #' using a fitted GenSVM model. #' -#' @param fit Fitted \code{gensvm} object +#' @param object Fitted \code{gensvm} object #' @param \dots further arguments are passed to predict #' #' @return a vector of class labels, with the same type as the original class @@ -22,8 +22,9 @@ #' \code{\link{plot.gensvm}}, \code{\link{predict.gensvm.grid}}, #' \code{\link{gensvm}}, \code{\link{gensvm-package}} #' +#' @method fitted gensvm +#' #' @export -#' @aliases fitted #' #' @examples #' x <- iris[, -5] @@ -36,7 +37,8 @@ #' # compute the accuracy with gensvm.accuracy #' gensvm.accuracy(y, yhat) #' -fitted.gensvm <- function(fit, ...) +fitted.gensvm <- function(object, ...) { - return(predict(fit, ...)) + x <- eval.parent(object$call$x) + return(predict(object, x, ...)) } diff --git a/R/fitted.gensvm.grid.R b/R/fitted.gensvm.grid.R index c6887c0..f24ec64 100644 --- a/R/fitted.gensvm.grid.R +++ b/R/fitted.gensvm.grid.R @@ -3,7 +3,7 @@ #' @description Wrapper to get the fitted class labels from the best estimator #' of the fitted GenSVMGrid model. Only works if refit was enabled. #' -#' @param grid A \code{gensvm.grid} object +#' @param object A \code{gensvm.grid} object #' @param \dots further arguments are passed to fitted #' #' @return a vector of class labels, with the same type as the original class @@ -22,8 +22,9 @@ #' \code{\link{plot.gensvm}}, \code{\link{predict.gensvm.grid}}, #' \code{\link{gensvm}}, \code{\link{gensvm-package}} #' +#' @method fitted gensvm.grid +#' #' @export -#' @aliases fitted #' #' @examples #' x <- iris[, -5] @@ -36,7 +37,7 @@ #' # compute the accuracy with gensvm.accuracy #' gensvm.accuracy(y, yhat) #' -fitted.gensvm.grid <- function(grid, ...) +fitted.gensvm.grid <- function(object, ...) { - return(predict(grid, ...)) + return(predict(object, ...)) } diff --git a/R/plot.gensvm.R b/R/plot.gensvm.R index d490c1e..129d90a 100644 --- a/R/plot.gensvm.R +++ b/R/plot.gensvm.R @@ -5,10 +5,11 @@ #' 3 classes. For more than 3 classes, the simplex space is too high #' dimensional to easily visualize. #' -#' @param fit A fitted \code{gensvm} object -#' @param x the dataset to plot (if NULL the training data is used) -#' @param y the labels to color points with (if NULL the predicted labels are -#' used) +#' @param x A fitted \code{gensvm} object +#' @param labels the labels to color points with. If this is omitted the +#' predicted labels are used. +#' @param newdata the dataset to plot. If this is NULL the training data is +#' used. #' @param with.margins plot the margins #' @param with.shading show shaded areas for the class regions #' @param with.legend show the legend for the class labels @@ -62,19 +63,20 @@ #' plot(fit, x.test=x.mis) #' plot(fit, y.mis.true, x.test=x.mis) #' -plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, +plot.gensvm <- function(x, labels, newdata=NULL, with.margins=TRUE, with.shading=TRUE, with.legend=TRUE, center.plot=TRUE, xlim=NULL, ylim=NULL, ...) { + fit <- x if (!(fit$n.classes %in% c(2,3))) { cat("Error: Can only plot with 2 or 3 classes\n") return(invisible(NULL)) } - x.test <- if(is.null(x.test)) eval.parent(fit$call$x) else x.test + newdata <- if(is.null(newdata)) eval.parent(fit$call$x) else newdata # Sanity check - if (ncol(x.test) != fit$n.features) { + if (ncol(newdata) != fit$n.features) { cat("Error: Number of features of fitted model and testing data disagree.\n") return(invisible(NULL)) @@ -83,20 +85,20 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, x.train <- eval.parent(fit$call$x) if (fit$kernel == 'linear') { V <- coef(fit) - Z <- cbind(matrix(1, dim(x.test)[1], 1), as.matrix(x.test)) + Z <- cbind(matrix(1, dim(newdata)[1], 1), as.matrix(newdata)) S <- Z %*% V - y.pred.orig <- predict(fit, x.test) + y.pred.orig <- predict(fit, newdata) } else { kernels <- c("linear", "poly", "rbf", "sigmoid") kernel.idx <- which(kernels == fit$kernel) - 1 plotdata <- .Call("R_gensvm_plotdata_kernels", - as.matrix(x.test), + as.matrix(newdata), as.matrix(x.train), as.matrix(fit$V), as.integer(nrow(fit$V)), as.integer(ncol(fit$V)), as.integer(nrow(x.train)), - as.integer(nrow(x.test)), + as.integer(nrow(newdata)), as.integer(fit$n.features), as.integer(fit$n.classes), as.integer(kernel.idx), @@ -116,7 +118,7 @@ plot.gensvm <- function(fit, y, x.test=NULL, with.margins=TRUE, y.pred <- y.pred.orig } - labels <- if(missing(y)) y.pred else match(y, classes) + labels <- if(missing(labels)) y.pred else match(labels, classes) colors <- gensvm.plot.colors(fit$n.classes) markers <- gensvm.plot.markers(fit$n.classes) diff --git a/R/plot.gensvm.grid.R b/R/plot.gensvm.grid.R index abb0601..6a34024 100644 --- a/R/plot.gensvm.grid.R +++ b/R/plot.gensvm.grid.R @@ -4,7 +4,7 @@ #' model in the provided GenSVMGrid object. See the documentation for #' \code{\link{plot.gensvm}} for more information. #' -#' @param grid A \code{gensvm.grid} object trained with refit=TRUE +#' @param x A \code{gensvm.grid} object trained with refit=TRUE #' @param ... further arguments are passed to the plot function #' #' @return returns the object passed as input @@ -31,12 +31,12 @@ #' grid <- gensvm.grid(x, y) #' plot(grid, x) #' -plot.gensvm.grid <- function(grid, ...) +plot.gensvm.grid <- function(x, ...) { - if (is.null(grid$best.estimator)) { + if (is.null(x$best.estimator)) { cat("Error: Can't plot, the best.estimator element is NULL\n") - return + return(invisible(NULL)) } - fit <- grid$best.estimator + fit <- x$best.estimator return(plot(fit, ...)) } diff --git a/R/predict.gensvm.R b/R/predict.gensvm.R index 0b5fa91..43a0d52 100644 --- a/R/predict.gensvm.R +++ b/R/predict.gensvm.R @@ -3,7 +3,7 @@ #' @description This function predicts the class labels of new data using a #' fitted GenSVM model. #' -#' @param fit Fitted \code{gensvm} object +#' @param object Fitted \code{gensvm} object #' @param newdata Matrix of new data for which predictions need to be made. #' @param add.rownames add the rownames from the training data to the #' predictions @@ -45,7 +45,7 @@ #' # compute the accuracy with gensvm.accuracy #' gensvm.accuracy(y.test, y.test.pred) #' -predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...) +predict.gensvm <- function(object, newdata, add.rownames=FALSE, ...) { ## Implementation note: ## - It might seem that it would be faster to do the prediction directly in @@ -55,48 +55,48 @@ predict.gensvm <- function(fit, newdata, add.rownames=FALSE, ...) ## the C implementation is *much* faster than doing it in R. if (missing(newdata)) { - newdata <- eval.parent(fit$call$x) + newdata <- eval.parent(object$call$x) } x.test <- as.matrix(newdata) # Sanity check - if (ncol(x.test) != fit$n.features) { + if (ncol(x.test) != object$n.features) { cat("Error: Number of features of fitted model and testing", "data disagree.\n") return(invisible(NULL)) } - x.train <- eval.parent(fit$call$x) - if (fit$kernel == 'linear') { + x.train <- eval.parent(object$call$x) + if (object$kernel == 'linear') { y.pred.c <- .Call("R_gensvm_predict", as.matrix(x.test), - as.matrix(fit$V), + as.matrix(object$V), as.integer(nrow(x.test)), as.integer(ncol(x.test)), - as.integer(fit$n.classes) + as.integer(object$n.classes) ) } else { kernels <- c("linear", "poly", "rbf", "sigmoid") - kernel.idx <- which(kernels == fit$kernel) - 1 + kernel.idx <- which(kernels == object$kernel) - 1 y.pred.c <- .Call("R_gensvm_predict_kernels", as.matrix(x.test), as.matrix(x.train), - as.matrix(fit$V), - as.integer(nrow(fit$V)), - as.integer(ncol(fit$V)), + as.matrix(object$V), + as.integer(nrow(object$V)), + as.integer(ncol(object$V)), as.integer(nrow(x.train)), as.integer(nrow(x.test)), - as.integer(fit$n.features), - as.integer(fit$n.classes), + as.integer(object$n.features), + as.integer(object$n.classes), as.integer(kernel.idx), - fit$gamma, - fit$coef, - fit$degree, - fit$kernel.eigen.cutoff + object$gamma, + object$coef, + object$degree, + object$kernel.eigen.cutoff ) } - yhat <- fit$classes[y.pred.c] + yhat <- object$classes[y.pred.c] if (add.rownames) { yhat <- matrix(yhat, length(yhat), 1) diff --git a/R/predict.gensvm.grid.R b/R/predict.gensvm.grid.R index df48d76..acc838f 100644 --- a/R/predict.gensvm.grid.R +++ b/R/predict.gensvm.grid.R @@ -7,9 +7,9 @@ #' this model is only available if \code{refit=TRUE} was specified in the #' \code{\link{gensvm.grid}} call (the default). #' -#' @param grid A \code{gensvm.grid} object trained with \code{refit=TRUE} -#' @param newx Matrix of new values for \code{x} for which predictions need to -#' be computed. +#' @param object A \code{gensvm.grid} object trained with \code{refit=TRUE} +#' @param newdata Matrix of new values for \code{x} for which predictions need +#' to be computed. #' @param \dots further arguments are passed to predict.gensvm() #' #' @return a vector of class labels, with the same type as the original class @@ -44,12 +44,12 @@ #' # predict training sample #' y.hat <- predict(grid, x) #' -predict.gensvm.grid <- function(grid, newx, ...) +predict.gensvm.grid <- function(object, newdata, ...) { - if (is.null(grid$best.estimator)) { + if (is.null(object$best.estimator)) { cat("Error: Can't predict, the best.estimator element is NULL\n") return } - return(predict(grid$best.estimator, newx, ...)) + return(predict(object$best.estimator, newdata, ...)) } diff --git a/R/print.gensvm.R b/R/print.gensvm.R index 724806b..54397da 100644 --- a/R/print.gensvm.R +++ b/R/print.gensvm.R @@ -2,7 +2,7 @@ #' #' @description Prints a short description of the fitted GenSVM model #' -#' @param fit A \code{gensvm} object to print +#' @param x A \code{gensvm} object to print #' @param \dots further arguments are ignored #' #' @return returns the object passed as input. This can be useful for chaining @@ -36,41 +36,41 @@ #' fit <- gensvm(x, y) #' predict(print(fit), x) #' -print.gensvm <- function(fit, ...) +print.gensvm <- function(x, ...) { cat("Data:\n") - cat("\tn.objects:", fit$n.objects, "\n") - cat("\tn.features:", fit$n.features, "\n") - cat("\tn.classes:", fit$n.classes, "\n") - if (is.factor(fit$classes)) - cat("\tclasses:", levels(fit$classes), "\n") + cat("\tn.objects:", x$n.objects, "\n") + cat("\tn.features:", x$n.features, "\n") + cat("\tn.classes:", x$n.classes, "\n") + if (is.factor(x$classes)) + cat("\tclasses:", levels(x$classes), "\n") else - cat("\tclasses:", fit$classes, "\n") + cat("\tclasses:", x$classes, "\n") cat("Parameters:\n") - cat("\tp:", fit$p, "\n") - cat("\tlambda:", fit$lambda, "\n") - cat("\tkappa:", fit$kappa, "\n") - cat("\tepsilon:", fit$epsilon, "\n") - cat("\tweights:", fit$weights, "\n") - cat("\tmax.iter:", fit$max.iter, "\n") - cat("\trandom.seed:", fit$random.seed, "\n") - if (is.factor(fit$kernel)) { - cat("\tkernel:", levels(fit$kernel)[as.numeric(fit$kernel)], "\n") + cat("\tp:", x$p, "\n") + cat("\tlambda:", x$lambda, "\n") + cat("\tkappa:", x$kappa, "\n") + cat("\tepsilon:", x$epsilon, "\n") + cat("\tweights:", x$weights, "\n") + cat("\tmax.iter:", x$max.iter, "\n") + cat("\trandom.seed:", x$random.seed, "\n") + if (is.factor(x$kernel)) { + cat("\tkernel:", levels(x$kernel)[as.numeric(x$kernel)], "\n") } else { - cat("\tkernel:", fit$kernel, "\n") + cat("\tkernel:", x$kernel, "\n") } - if (fit$kernel %in% c("poly", "rbf", "sigmoid")) { - cat("\tkernel.eigen.cutoff:", fit$kernel.eigen.cutoff, "\n") - cat("\tgamma:", fit$gamma, "\n") + if (x$kernel %in% c("poly", "rbf", "sigmoid")) { + cat("\tkernel.eigen.cutoff:", x$kernel.eigen.cutoff, "\n") + cat("\tgamma:", x$gamma, "\n") } - if (fit$kernel %in% c("poly", "sigmoid")) - cat("\tcoef:", fit$coef, "\n") - if (fit$kernel == 'poly') - cat("\tdegree:", fit$degree, "\n") + if (x$kernel %in% c("poly", "sigmoid")) + cat("\tcoef:", x$coef, "\n") + if (x$kernel == 'poly') + cat("\tdegree:", x$degree, "\n") cat("Results:\n") - cat("\ttime:", fit$training.time, "\n") - cat("\tn.iter:", fit$n.iter, "\n") - cat("\tn.support:", fit$n.support, "\n") + cat("\ttime:", x$training.time, "\n") + cat("\tn.iter:", x$n.iter, "\n") + cat("\tn.support:", x$n.support, "\n") - invisible(fit) + invisible(x) } diff --git a/R/print.gensvm.grid.R b/R/print.gensvm.grid.R index 5e4c5da..3e1bb69 100644 --- a/R/print.gensvm.grid.R +++ b/R/print.gensvm.grid.R @@ -2,7 +2,7 @@ #' #' @description Prints the summary of the fitted GenSVMGrid model #' -#' @param grid a \code{gensvm.grid} object to print +#' @param x a \code{gensvm.grid} object to print #' @param \dots further arguments are ignored #' #' @return returns the object passed as input @@ -32,38 +32,38 @@ #' grid <- gensvm.grid(x, y) #' print(grid) #' -print.gensvm.grid <- function(grid, ...) +print.gensvm.grid <- function(x, ...) { cat("Data:\n") - cat("\tn.objects:", grid$n.objects, "\n") - cat("\tn.features:", grid$n.features, "\n") - cat("\tn.classes:", grid$n.classes, "\n") - if (is.factor(grid$classes)) - cat("\tclasses:", levels(grid$classes), "\n") + cat("\tn.objects:", x$n.objects, "\n") + cat("\tn.features:", x$n.features, "\n") + cat("\tn.classes:", x$n.classes, "\n") + if (is.factor(x$classes)) + cat("\tclasses:", levels(x$classes), "\n") else - cat("\tclasses:", grid$classes, "\n") + cat("\tclasses:", x$classes, "\n") cat("Config:\n") - cat("\tNumber of cv splits:", grid$n.splits, "\n") - not.run <- sum(is.na(grid$cv.results$rank.test.score)) + cat("\tNumber of cv splits:", x$n.splits, "\n") + not.run <- sum(is.na(x$cv.results$rank.test.score)) if (not.run > 0) { - cat("\tParameter grid size:", dim(grid$param.grid)[1]) + cat("\tParameter grid size:", dim(x$param.grid)[1]) cat(" (", not.run, " incomplete)", sep="") cat("\n") } else { - cat("\tParameter grid size:", dim(grid$param.grid)[1], "\n") + cat("\tParameter grid size:", dim(x$param.grid)[1], "\n") } cat("Results:\n") - cat("\tTotal grid search time:", grid$total.time, "\n") - if (!is.na(grid$best.index)) { - best <- grid$cv.results[grid$best.index, ] + cat("\tTotal grid search time:", x$total.time, "\n") + if (!is.na(x$best.index)) { + best <- x$cv.results[x$best.index, ] cat("\tBest mean test score:", best$mean.test.score, "\n") cat("\tBest mean fit time:", best$mean.fit.time, "\n") - for (name in colnames(grid$best.params)) { - val <- grid$best.params[[name]] + for (name in colnames(x$best.params)) { + val <- x$best.params[[name]] val <- if(is.factor(val)) levels(val)[val] else val cat("\tBest parameter", name, "=", val, "\n") } } - invisible(grid) + invisible(x) } |
