aboutsummaryrefslogtreecommitdiff
path: root/analysis/scripts/descriptive_annotations.py
blob: 2afdc4222e1b17fd213d9b593f43d0c93f4cb406 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# -*- coding: utf-8 -*-

"""Extract descriptive statistics for the time series

This script computes a descriptive statistic (min, max, mean, or standard
deviation) of the number of unique annotations per dataset, as read from the
summary files.

Author: Gertjan van den Burg
Copyright (c) 2020 - The Alan Turing Institute
License: See the LICENSE file.

"""


import argparse
import json
import os
import statistics

# Number of summary files (one per dataset) that the summary directory is
# expected to contain; the count is checked before any file is read.
N_DATASETS = 42


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with ``summary_dir`` (directory holding the summary files)
        and ``type`` (one of ``"min"``, ``"max"``, ``"mean"``, ``"std"``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-s", "--summary-dir", required=True, help="Directory with summary files"
    )
    stat_names = ["min", "max", "mean", "std"]
    cli.add_argument(
        "-t",
        "--type",
        required=True,
        choices=stat_names,
        help="Type of statistic to compute",
    )
    namespace = cli.parse_args()
    return namespace


def load_unique_annotations(summary_dir, n_expected=None):
    """Count the unique annotations in every summary file of a directory.

    Each entry of ``summary_dir`` is parsed as JSON. Its ``"annotations"``
    value must be a mapping whose values are iterables of annotated points
    (presumably one entry per annotator -- confirm against the summary-file
    format). The points from all entries are pooled and de-duplicated.

    Parameters
    ----------
    summary_dir : str
        Directory containing exactly ``n_expected`` summary files.
    n_expected : int, optional
        Number of entries expected in ``summary_dir``. Defaults to the
        module-level ``N_DATASETS``.

    Returns
    -------
    list of int
        Number of unique annotations per file, ordered by file name.

    Raises
    ------
    ValueError
        If ``summary_dir`` does not contain exactly ``n_expected`` entries.
    """
    if n_expected is None:
        n_expected = N_DATASETS

    files = sorted(os.listdir(summary_dir))
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and a missing or extra file would silently skew the
    # reported statistics.
    if len(files) != n_expected:
        raise ValueError(
            "Expected %i summary files in %r, found %i"
            % (n_expected, summary_dir, len(files))
        )

    n_uniq_anno = []
    for f in files:
        path = os.path.join(summary_dir, f)
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as fp:
            data = json.load(fp)

        # Pool the annotations of all entries; points must be hashable
        # (e.g. integer indices) to be de-duplicated in a set.
        all_anno = set()
        for annotations in data["annotations"].values():
            all_anno.update(annotations)
        n_uniq_anno.append(len(all_anno))
    return n_uniq_anno


def main():
    """Print one descriptive statistic of the per-file unique-annotation counts.

    The statistic is selected by the ``--type`` command-line option. ``min``
    and ``max`` are printed as integers; ``mean`` and ``std`` (sample
    standard deviation) are printed with one decimal place.

    Raises
    ------
    ValueError
        If the requested statistic type is not recognised.
    """
    args = parse_args()

    # Dispatch table from CLI statistic name to the function computing it.
    statistic_for = {
        "max": max,
        "mean": statistics.mean,
        "std": statistics.stdev,
        "min": min,
    }
    func = statistic_for.get(args.type)
    if func is None:
        raise ValueError("Unknown type")

    counts = load_unique_annotations(args.summary_dir)
    value = func(counts)
    # The trailing literal '%' is presumably a LaTeX comment marker, so the
    # number can be \input into a document without trailing whitespace --
    # NOTE(review): confirm with the consuming document.
    fmt = "%i%%" if args.type in ("min", "max") else "%.1f%%"
    print(fmt % value)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()