aboutsummaryrefslogtreecommitdiff
path: root/execs/R/cpdbench_wbs.R
diff options
context:
space:
mode:
Diffstat (limited to 'execs/R/cpdbench_wbs.R')
-rw-r--r--execs/R/cpdbench_wbs.R102
1 files changed, 102 insertions, 0 deletions
diff --git a/execs/R/cpdbench_wbs.R b/execs/R/cpdbench_wbs.R
new file mode 100644
index 00000000..e858b4df
--- /dev/null
+++ b/execs/R/cpdbench_wbs.R
@@ -0,0 +1,102 @@
+#' ---
+#' title: Wrapper for wbs package in TCPDBench
+#' author: G.J.J. van den Burg
+#' date: 2019-09-28
+#' license: See the LICENSE file.
+#' copyright: 2019, The Alan Turing Institute
+#' ---
+
+library(argparse)
+library(wbs)
+
+load.utils <- function() {
+ # get the name of the current script so we can load utils.R (yay, R!)
+ cmd.args <- commandArgs(trailingOnly=F)
+ file.arg <- "--file="
+ this.script <- sub(file.arg, "", cmd.args[grep(file.arg, cmd.args)])
+ this.dir <- dirname(this.script)
+ utils.script <- file.path(this.dir, 'utils.R')
+ source(utils.script)
+}
+
+parse.args <- function() {
+ parser <- ArgumentParser(description='Wrapper for wbs package')
+ parser$add_argument('-i',
+ '--input',
+ help='path to the input data file',
+ required=TRUE
+ )
+ parser$add_argument('-o',
+ '--output',
+ help='path to the output file'
+ )
+ parser$add_argument('-K', '--Kmax', choices=c('default', 'max'),
+ help='the maximum number of changepoints',
+ default='max')
+ parser$add_argument('-p', '--penalty', choices=c('SSIC', 'BIC', 'MBIC'),
+ help='The penalty to use in WBS')
+ parser$add_argument("-g", "--integrated", choices=c("true", "false"),
+ help="Whether to use integrated WBS or not")
+ return(parser$parse_args())
+}
+
+main <- function() {
+ args <- parse.args()
+
+ # load the data
+ data <- load.dataset(args$input)
+
+ # copy defaults from the wbs package and set params
+ defaults <- list(M=5000, rand.intervals=T)
+ if (args$Kmax == 'default')
+ args$Kmax <- 50
+ else
+ args$Kmax <- data$original$n_obs
+
+ if (args$integrated == "true")
+ args$integrated = TRUE
+ else
+ args$integrated = FALSE
+ params <- make.param.list(args, defaults)
+
+ if (data$original$n_dim > 1) {
+ # wbs package doesn't handle multidimensional data
+ exit.error.multidim(data$original, args, params)
+ }
+
+ vec <- as.vector(data$mat)
+ start.time <- Sys.time()
+
+ # We use the SSIC penalty as this is used in the WBS paper and is the
+ # default in the WBS package (for plot.wbs, for instance).
+
+ result <- tryCatch({
+ out <- wbs(vec, M=params$M, rand.intervals=params$rand.intervals,
+ integrated=params$integrated)
+ cpt <- changepoints(out, Kmax=params$Kmax)
+ if (params$penalty == "SSIC")
+ locs <- cpt$cpt.ic$ssic.penalty
+ else if (params$penalty == "BIC")
+ locs <- cpt$cpt.ic$bic.penalty
+ else if (params$penalty == "MBIC")
+ locs <- cpt$cpt.ic$mbic.penalty
+ locs <- sort(locs)
+ list(locations=locs, error=NULL)
+ }, error=function(e) {
+ return(list(locations=NULL, error=e$message))
+ })
+ stop.time <- Sys.time()
+ runtime <- difftime(stop.time, start.time, units='secs')
+
+ if (!is.null(result$error)) {
+ exit.with.error(data$original, args, params, result$error)
+ }
+
+ # convert to 0-based indices.
+ locations <- as.list(result$locations - 1)
+
+ exit.success(data$original, args, params, locations, runtime)
+}
+
+load.utils()
+main()