1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
#' ---
#' title: Wrapper for wbs package in TCPDBench
#' author: G.J.J. van den Burg
#' date: 2019-09-28
#' license: See the LICENSE file.
#' copyright: 2019, The Alan Turing Institute
#' ---
library(argparse)
library(wbs)
# Source utils.R from the directory containing the currently running script.
#
# Rscript passes the script path to the R process as an argument of the form
# "--file=<path>". That argument is located so utils.R can be found regardless
# of the current working directory.
#
# @param cmd.args Character vector of the full command line; defaults to
#   commandArgs(trailingOnly = FALSE). Exposed mainly for testing.
# @return The (invisible) value returned by source().
load.utils <- function(cmd.args = commandArgs(trailingOnly = FALSE)) {
  file.arg <- "--file="
  # Only arguments that *start* with "--file=" count; the first one is the
  # interpreter's own option (user arguments come after "--args"), so a user
  # value that merely contains "--file=" is not picked up.
  candidates <- cmd.args[startsWith(cmd.args, file.arg)]
  if (length(candidates) == 0) {
    stop("Cannot locate the running script: no '", file.arg,
         "' argument found. Run this wrapper with Rscript.", call. = FALSE)
  }
  this.script <- substring(candidates[1], nchar(file.arg) + 1)
  utils.script <- file.path(dirname(this.script), "utils.R")
  source(utils.script)
}
# Parse the wrapper's command-line options.
#
# Every optional flag has a default, so main() never receives NULL for
# --Kmax, --penalty, or --integrated:
#   --penalty    defaults to "SSIC" (the penalty used in the WBS paper)
#   --integrated defaults to "true" (matching wbs's integrated = TRUE)
#
# @return The named list produced by the argparse parser's parse_args().
parse.args <- function() {
  parser <- ArgumentParser(description = "Wrapper for wbs package")
  parser$add_argument("-i", "--input",
                      help = "path to the input data file",
                      required = TRUE)
  parser$add_argument("-o", "--output",
                      help = "path to the output file")
  parser$add_argument("-K", "--Kmax", choices = c("default", "max"),
                      help = "the maximum number of changepoints",
                      default = "max")
  parser$add_argument("-p", "--penalty", choices = c("SSIC", "BIC", "MBIC"),
                      help = "The penalty to use in WBS",
                      default = "SSIC")
  parser$add_argument("-g", "--integrated", choices = c("true", "false"),
                      help = "Whether to use integrated WBS or not",
                      default = "true")
  parser$parse_args()
}
# Run WBS on the input series and report the detected changepoints.
#
# Steps:
#   1. Parse the options and load the dataset.
#   2. Resolve the symbolic options (Kmax, integrated, penalty).
#   3. Refuse multidimensional data.
#   4. Run wbs() and changepoints().
#   5. Pick the changepoint set for the chosen penalty.
#   6. Hand the 0-based locations and the runtime to exit.success().
# Any error raised while fitting is reported through exit.with.error().
main <- function() {
  args <- parse.args()

  # load the data
  data <- load.dataset(args$input)

  # copy defaults from the wbs package and set params
  defaults <- list(M = 5000, rand.intervals = TRUE)

  # Resolve the symbolic --Kmax choice: "default" maps to 50; otherwise
  # ("max") allow as many changepoints as there are observations.
  args$Kmax <- if (identical(args$Kmax, "default")) {
    50
  } else {
    data$original$n_obs
  }

  # If --integrated or --penalty is omitted, argparse leaves the value NULL.
  # Comparing NULL inside if() would fail with "argument is of length zero",
  # so fall back to the wbs package's integrated = TRUE and to SSIC. SSIC is
  # the penalty used in the WBS paper and the default in the wbs package
  # (for plot.wbs, for instance).
  args$integrated <- is.null(args$integrated) || args$integrated == "true"
  if (is.null(args$penalty)) {
    args$penalty <- "SSIC"
  }

  params <- make.param.list(args, defaults)

  if (data$original$n_dim > 1) {
    # wbs package doesn't handle multidimensional data
    exit.error.multidim(data$original, args, params)
  }

  vec <- as.vector(data$mat)

  start.time <- Sys.time()
  result <- tryCatch({
    out <- wbs(vec, M = params$M, rand.intervals = params$rand.intervals,
               integrated = params$integrated)
    cpt <- changepoints(out, Kmax = params$Kmax)
    # Map the penalty name onto its element of cpt$cpt.ic. An unknown name
    # is an error here, so it is reported via exit.with.error() below.
    ic.name <- switch(params$penalty,
      SSIC = "ssic.penalty",
      BIC = "bic.penalty",
      MBIC = "mbic.penalty",
      stop("Unknown penalty: ", params$penalty)
    )
    # sort() also drops any NA entries (na.last = NA by default).
    locs <- sort(cpt$cpt.ic[[ic.name]])
    list(locations = locs, error = NULL)
  }, error = function(e) {
    list(locations = NULL, error = conditionMessage(e))
  })
  stop.time <- Sys.time()
  runtime <- difftime(stop.time, start.time, units = "secs")

  if (!is.null(result$error)) {
    exit.with.error(data$original, args, params, result$error)
  }

  # convert to 0-based indices.
  locations <- as.list(result$locations - 1)
  exit.success(data$original, args, params, locations, runtime)
}
# Script entry point: make the helpers from utils.R (load.dataset,
# make.param.list, exit.*) available first, then run the wrapper.
invisible(load.utils())
main()
|