# path: root/execs/R/cpdbench_wbs.R
# blob: e858b4dfdc8b6015807ca091eeb72192efafd542
#' ---
#' title: Wrapper for wbs package in TCPDBench
#' author: G.J.J. van den Burg
#' date: 2019-09-28
#' license: See the LICENSE file.
#' copyright: 2019, The Alan Turing Institute
#' ---

library(argparse)
library(wbs)

load.utils <- function() {
    # Source utils.R from the directory that contains this script.
    #
    # The script path is recovered from the "--file=<path>" entry of the
    # full command line (commandArgs with trailingOnly=FALSE), because R
    # offers no direct way to ask for the running script's location.
    #
    # utils.R is read with source() using its defaults, so its definitions
    # are evaluated in the global environment. The (invisible) result of
    # source() is returned.
    #
    # Stops with an error when no "--file=" argument is present, since the
    # location of utils.R can not be determined in that case.
    file.prefix <- "--file="
    cmd.args <- commandArgs(trailingOnly=FALSE)

    # startsWith() compares literally and only at the start of each
    # argument, so a user argument that merely contains "--file=" somewhere
    # in the middle is not mistaken for the script path.
    script.args <- cmd.args[startsWith(cmd.args, file.prefix)]
    if (length(script.args) == 0) {
        stop("Cannot locate utils.R: no '--file=' argument on the command ",
             "line (run this script with Rscript).", call.=FALSE)
    }

    # Use the first match only, so the path is always a single string.
    this.script <- substring(script.args[1], nchar(file.prefix) + 1)
    utils.script <- file.path(dirname(this.script), 'utils.R')
    source(utils.script)
}

parse.args <- function() {
    # Define the command line interface of this wrapper and parse the
    # arguments that were passed to the script.
    #
    # Options:
    #   -i/--input       path to the input data file (required)
    #   -o/--output      path to the output file
    #   -K/--Kmax        'default' or 'max'; defaults to 'max'
    #   -p/--penalty     one of 'SSIC', 'BIC', 'MBIC' (no default)
    #   -g/--integrated  'true' or 'false' (no default)
    #
    # Returns the result of parse_args() on the configured parser.
    cli <- ArgumentParser(description="Wrapper for wbs package")

    # input and output locations
    cli$add_argument(
        "-i", "--input",
        help="path to the input data file",
        required=TRUE
    )
    cli$add_argument(
        "-o", "--output",
        help="path to the output file"
    )

    # method settings; all are restricted to a fixed set of choices
    cli$add_argument(
        "-K", "--Kmax",
        choices=c("default", "max"),
        help="the maximum number of changepoints",
        default="max"
    )
    cli$add_argument(
        "-p", "--penalty",
        choices=c("SSIC", "BIC", "MBIC"),
        help="The penalty to use in WBS"
    )
    cli$add_argument(
        "-g", "--integrated",
        choices=c("true", "false"),
        help="Whether to use integrated WBS or not"
    )

    cli$parse_args()
}

main <- function() {
    # Entry point of the wrapper: parse the command line, load the dataset,
    # run WBS on it and report the detected changepoints (or the error).
    #
    # Results are handed to the exit.* helpers, which are not defined in
    # this file (presumably provided by utils.R and presumably terminating
    # the script -- verify there).
    args <- parse.args()

    # load the data
    data <- load.dataset(args$input)

    # copy defaults from the wbs package and set params
    defaults <- list(M=5000, rand.intervals=TRUE)

    # Translate the symbolic --Kmax choice into a number: 'default' maps to
    # 50, anything else to the number of observations in the series.
    if (args$Kmax == 'default') {
        args$Kmax <- 50
    } else {
        args$Kmax <- data$original$n_obs
    }

    # --integrated has no default, so it is NULL when omitted. Fail with a
    # clear message instead of an opaque "argument is of length zero".
    if (is.null(args$integrated)) {
        stop("--integrated must be given as 'true' or 'false'", call.=FALSE)
    }
    # Convert the textual flag to a logical.
    args$integrated <- args$integrated == "true"
    params <- make.param.list(args, defaults)

    if (data$original$n_dim > 1) {
        # wbs package doesn't handle multidimensional data
        exit.error.multidim(data$original, args, params)
    }

    vec <- as.vector(data$mat)
    start.time <- Sys.time()

    # We use the SSIC penalty as this is used in the WBS paper and is the 
    # default in the WBS package (for plot.wbs, for instance).

    # Any error raised while running WBS is captured and returned as a
    # message so that it can be reported through exit.with.error() below.
    result <- tryCatch({
        # Name of the element of cpt$cpt.ic that belongs to each penalty.
        ic.fields <- c(SSIC="ssic.penalty", BIC="bic.penalty",
                       MBIC="mbic.penalty")
        if (is.null(params$penalty) ||
            !(params$penalty %in% names(ic.fields))) {
            stop("missing or unknown penalty; expected one of: ",
                 paste(names(ic.fields), collapse=", "))
        }

        out <- wbs(vec, M=params$M, rand.intervals=params$rand.intervals,
                   integrated=params$integrated)
        cpt <- changepoints(out, Kmax=params$Kmax)
        locs <- cpt$cpt.ic[[ic.fields[[params$penalty]]]]
        # sort() also drops NA entries (its default is na.last=NA), so an
        # NA result turns into an empty vector here.
        locs <- sort(locs)
        list(locations=locs, error=NULL)
    }, error=function(e) {
        list(locations=NULL, error=conditionMessage(e))
    })
    stop.time <- Sys.time()
    runtime <- difftime(stop.time, start.time, units='secs')

    if (!is.null(result$error)) {
        exit.with.error(data$original, args, params, result$error)
    }

    # convert to 0-based indices.
    locations <- as.list(result$locations - 1)

    exit.success(data$original, args, params, locations, runtime)
}

# Script entry: the helpers must be loaded before main() runs, because
# main() calls functions that are not defined in this file (load.dataset,
# make.param.list, exit.* -- presumably provided by utils.R).
load.utils()
main()