aboutsummaryrefslogtreecommitdiff
path: root/execs/R/cpdbench_ocp.R
blob: 308e5bacc41f3ac1adfdc8699d484bb4fc5f50e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#' ---
#' title: Wrapper for ocp package in TCPDBench
#' author: G.J.J. van den Burg
#' date: 2019-10-05
#' license: See the LICENSE file.
#' copyright: 2019, The Alan Turing Institute
#' ---

library(argparse)
library(ocp)

load.utils <- function() {
    # Source the utils.R file that lives in the same directory as this script.
    #
    # The script's own path is recovered from the "--file=<path>" entry of
    # the full command line, since R offers no direct way to ask for it
    # (yay, R!).
    #
    # Takes no arguments and is called for its side effect of sourcing
    # utils.R. Stops with an explicit error when no "--file=" argument is
    # present (for example in an interactive session), instead of failing
    # inside source() with an unrelated message.
    file.pattern <- "^--file="
    cmd.args <- commandArgs(trailingOnly=FALSE)
    file.args <- grep(file.pattern, cmd.args, value=TRUE)
    if (length(file.args) == 0) {
        stop("Unable to locate utils.R: no '--file=' argument found. ",
             "Run this script through Rscript.", call.=FALSE)
    }

    # Use only the first match so that a later argument which also starts
    # with "--file=" cannot turn the path into a vector.
    this.script <- sub(file.pattern, "", file.args[1])
    utils.script <- file.path(dirname(this.script), 'utils.R')
    source(utils.script)
}

parse.args <- function() {
    # Define and parse the command line interface of this wrapper.
    #
    # Options:
    #   -i / --input   path to the input data file (required)
    #   -o / --output  path to the output file
    #   -l / --lambda  lambda for the constant hazard function
    #                  (integer, default 100)
    #   --prior-a      prior alpha for student-t (double, default 1)
    #   --prior-b      prior beta for student-t (double, default 1)
    #   --prior-k      prior kappa for student-t (double, default 1)
    #
    # Returns the value of parser$parse_args(). The rest of this script
    # reads the input path from it as args$input.
    parser <- ArgumentParser(description='Wrapper for ocp package')
    parser$add_argument('-i',
                        '--input',
                        help='path to the input data file',
                        required=TRUE
    )
    parser$add_argument('-o',
                        '--output',
                        help='path to the output file')
    parser$add_argument('-l',
                        '--lambda',
                        help='lambda parameter for constant hazard function',
                        type='integer',
                        default=100
                        )
    parser$add_argument('--prior-a',
                        help='Prior alpha for student-t',
                        type='double',
                        default=1)
    parser$add_argument('--prior-b',
                        help='Prior beta for student-t',
                        type='double',
                        default=1
    )
    parser$add_argument('--prior-k',
                        help='Prior kappa for student-t',
                        type='double',
                        default=1
    )

    return(parser$parse_args())
}

main <- function() {
    # Run the ocp online changepoint detector on one dataset.
    #
    # Parses the command line, loads the dataset, calls onlineCPD() with a
    # mix of fixed settings and command line parameters, and passes either
    # the detected locations or the error message to the reporting helpers
    # (exit.success / exit.with.error, which are not defined in this file).
    args <- parse.args()
    data <- load.dataset(args$input)

    # settings that are never changed from the command line
    defaults <- list(missPts="none",
                     cpthreshold=0.5, # unused by us
                     truncRlim=10^(-4),
                     minRlength=1,
                     maxRlength=10^4, # bigger than any of our datasets
                     minsep=1,
                     maxsep=10^4, # bigger than any of our datasets
                     multivariate=data$original$n_dim > 1
                     )

    # merge the fixed settings with the command line arguments
    params <- make.param.list(args, defaults)

    # constant hazard; the lambda argument is ignored in favour of the value
    # stored in the parameters
    hazard_func <- function(x, lambda) const_hazard(x, lambda=params$lambda)

    # we only use the gaussian model since the data is scaled
    prior.params <- list(m=0, k=params$prior_k, a=params$prior_a,
                         b=params$prior_b)
    model.params <- list(prior.params)

    # time the detection itself, whether it succeeds or fails
    start.time <- Sys.time()
    outcome <- tryCatch(
        {
            ocp.fit <- onlineCPD(data$mat, oCPD=NULL,
                                 missPts=params$missPts,
                                 hazard_func=hazard_func,
                                 probModel=list("gaussian"),
                                 init_params=model.params,
                                 multivariate=params$multivariate,
                                 cpthreshold=params$cpthreshold,
                                 truncRlim=params$truncRlim,
                                 minRlength=params$minRlength,
                                 maxRlength=params$maxRlength,
                                 minsep=params$minsep,
                                 maxsep=params$maxsep
                                 )
            max.cps <- ocp.fit$changepoint_lists$maxCPs[[1]]
            list(locations=as.vector(max.cps), error=NULL)
        },
        error=function(err) {
            list(locations=NULL, error=conditionMessage(err))
        }
    )
    stop.time <- Sys.time()
    runtime <- difftime(stop.time, start.time, units="secs")

    # NOTE(review): the code below assumes exit.with.error does not return
    # (presumably it quits the script) -- confirm in utils.R
    if (!is.null(outcome$error)) {
        exit.with.error(data$original, args, params, outcome$error)
    }

    # shift the 1-based indices to 0-based indices
    locations <- as.list(outcome$locations - 1)

    exit.success(data$original, args, params, locations, runtime)
}

# Script entry point. load.utils() has to run first: it sources utils.R, and
# main() calls helpers that are not defined in this file (load.dataset,
# make.param.list, exit.success, exit.with.error) -- presumably they are
# provided by utils.R.
load.utils()
main()