forked from nyu-mll/jiant
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_edge_data_labels.py
More file actions
executable file
·70 lines (58 loc) · 1.95 KB
/
get_edge_data_labels.py
File metadata and controls
executable file
·70 lines (58 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
# Helper script to extract set of labels from edge probing data.
#
# Usage:
# python get_edge_data_labels.py -o /path/to/edge/probing/data/labels.txt \
# -i /path/to/edge/probing/data/*.json
#
import argparse
import collections
import json
import logging as log
import os
import sys
from typing import Type
from tqdm import tqdm
from jiant.utils import utils
# Configure root logging at import time: INFO level, each message prefixed
# with a month/day 12-hour timestamp.
log.basicConfig(
    level=log.INFO,
    format="%(asctime)s: %(message)s",
    datefmt="%m/%d %I:%M:%S %p",
)
def count_labels(fname: str) -> collections.Counter:
    """Count labels across all targets in a file of edge probing examples.

    Args:
        fname: path of the data file; records are read with
            ``utils.load_json_data``.

    Returns:
        A ``collections.Counter`` mapping each label to the number of times it
        occurs across every target of every record in the file.
    """
    label_ctr = collections.Counter()
    record_iter = utils.load_json_data(fname)
    for record in tqdm(record_iter):
        for target in record["targets"]:
            label = target["label"]
            # A bare string is wrapped in a list so Counter.update counts the
            # whole label rather than its individual characters. Any other
            # value is assumed to be an iterable of labels.
            if isinstance(label, str):
                label = [label]
            label_ctr.update(label)
    return label_ctr
def main(args):
    """Build a label vocabulary file from edge probing data.

    Parses ``args``, counts labels in every input file with ``count_labels``,
    and writes the special tokens followed by the sorted distinct labels found
    to the output file, one label per line.

    Args:
        args: list of command-line argument strings (e.g. ``sys.argv[1:]``).

    Raises:
        SystemExit: from argparse when the arguments are invalid or missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", dest="output", type=str, required=True, help="Output file.")
    # -i is mandatory: without it args.inputs would be None and the loop below
    # would fail with an unhelpful TypeError instead of a usage message.
    parser.add_argument(
        "-i", dest="inputs", type=str, nargs="+", required=True, help="Input files."
    )
    parser.add_argument(
        "-s",
        dest="special_tokens",
        type=str,
        nargs="*",
        default=["-"],
        help="Additional special tokens to add at beginning of vocab list.",
    )
    args = parser.parse_args(args)

    label_ctr = collections.Counter()
    for fname in args.inputs:
        log.info("Counting labels in %s", fname)
        label_ctr.update(count_labels(fname))

    # Special tokens come first, in the order given, followed by found labels
    # in sorted order.
    # NOTE(review): a found label equal to a special token is written twice.
    all_labels = args.special_tokens + sorted(label_ctr)
    log.info(
        "%d labels in total (%d special + %d found)",
        len(all_labels),
        len(args.special_tokens),
        len(label_ctr),
    )
    # Explicit UTF-8 so non-ASCII labels are written identically on every
    # platform, rather than in the locale's default encoding.
    with open(args.output, "w", encoding="utf-8") as fd:
        for label in all_labels:
            fd.write(label + "\n")
if __name__ == "__main__":
    # Hand every command-line argument except the program name to main(),
    # then exit with status 0 once it returns.
    cli_args = sys.argv[1:]
    main(cli_args)
    sys.exit(0)