Skip to content

Commit ddfdfaf

Browse files
authored
Merge pull request #6 from nus-cs3244-ml-singapore-7/train-xlm-roberta-sh-ner
Add scripts to run NER inference on entire Singapore Hansard dataset
2 parents 0b6dbb6 + 3dc7325 commit ddfdfaf

3 files changed

Lines changed: 133 additions & 5 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# singapore-hansard-nlp
22
Singapore Hansard NLP
33

4+
[singapore-hansard-sentiment-ner-final.zip](https://drive.google.com/file/d/1xWDplG7ythfnv3EZjxhjEmtUHwS9st-E/view?usp=sharing)
5+
46
Note that 28 JSON files in the 12th session from 2011-10-10 to 2012-08-13 are excluded
57
as they use an old format that is difficult to parse.
68

@@ -24,8 +26,6 @@ python test_sentiment.py input.json models/xlm-roberta-base-sst-2-sh-sentiment
2426

2527
[xlm-roberta-base-sst-2-sh-sentiment.tar.xz](https://drive.google.com/file/d/1toqvkwWjXuHH0EIHHjJv9V9x5FZ-0Pba/view?usp=sharing)
2628

27-
[singapore-hansard-sentiment-final.zip](https://drive.google.com/file/d/14yZRPLvQ7usliO1WOdFKmqoWgtJjax0C/view?usp=sharing)
28-
2929
### Raw Results
3030

3131
xlm-roberta-base-sst-2

xlm-roberta-sh-ner/classify_ner.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import argparse
2+
import json
3+
import os
4+
5+
import torch
6+
from seqeval import metrics
7+
from transformers import XLMRobertaForTokenClassification, XLMRobertaTokenizerFast
8+
9+
from sh_ner_dataset import id_to_label
10+
11+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12+
13+
# NOTE: Assumes that sentiment has already been classified and JSON file is updated
14+
15+
def create_arg_parser():
16+
parser = argparse.ArgumentParser(
17+
description='Perform NER on Singapore Hansard using XLM-RoBERTa model')
18+
19+
parser.add_argument('input_dir_path', type=str,
20+
help='Path of directroy to read JSON files from.')
21+
22+
parser.add_argument('output_dir_path', type=str,
23+
help='Path of directroy to write JSON files to.')
24+
25+
parser.add_argument('model_name_or_dir', type=str,
26+
help='Name or directory of model.')
27+
28+
return parser
29+
30+
def main(input_dir_path, output_dir_path, model_name_or_dir):
31+
tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name_or_dir)
32+
model = XLMRobertaForTokenClassification.from_pretrained(model_name_or_dir).to(DEVICE)
33+
model.eval()
34+
35+
os.makedirs(output_dir_path, exist_ok=True)
36+
37+
for file_name in os.listdir(input_dir_path):
38+
if file_name.endswith('.json'):
39+
count = 0
40+
input_file_path = os.path.join(input_dir_path, file_name)
41+
with open(input_file_path) as json_file:
42+
data = json.load(json_file)
43+
44+
for session in data['sessions']:
45+
for speech in session['speeches']:
46+
for text_sentiment in speech['content']:
47+
text = text_sentiment['text']
48+
inputs = tokenizer(
49+
text,
50+
padding=False,
51+
truncation=True,
52+
return_special_tokens_mask=True,
53+
return_offsets_mapping=True)
54+
55+
with torch.no_grad():
56+
input_ids = torch.tensor(inputs['input_ids']).unsqueeze(0).to(DEVICE)
57+
attention_mask = torch.tensor(inputs['attention_mask']).unsqueeze(0).to(DEVICE)
58+
outputs = model(input_ids, attention_mask).logits
59+
predictions = torch.argmax(outputs, dim=2)[0].detach().cpu().numpy()
60+
61+
special_tokens_mask = inputs['special_tokens_mask']
62+
offset_mapping = inputs['offset_mapping']
63+
64+
start_index = 0
65+
end_index = 0
66+
previous_iob_entity = None
67+
68+
entities = []
69+
70+
for i in range(len(predictions)):
71+
if special_tokens_mask[i] == 0 and predictions[i] != 0:
72+
iob_entity = id_to_label(predictions[i])
73+
74+
if iob_entity[:2] == 'B-' or previous_iob_entity is None or iob_entity[2:] != previous_iob_entity[2:]:
75+
if previous_iob_entity is not None:
76+
label = previous_iob_entity[2:]
77+
start = start_index
78+
end = end_index
79+
80+
if text[start] == ' ':
81+
start += 1
82+
83+
word = text[start:end]
84+
85+
entities.append({
86+
'word': word,
87+
'start': start,
88+
'end': end,
89+
'label': label,
90+
})
91+
92+
start_index = offset_mapping[i][0]
93+
94+
end_index = offset_mapping[i][1]
95+
previous_iob_entity = iob_entity
96+
97+
if previous_iob_entity is not None:
98+
label = previous_iob_entity[2:]
99+
start = start_index
100+
end = end_index
101+
102+
if text[start] == ' ':
103+
start += 1
104+
105+
word = text[start:end]
106+
107+
entities.append({
108+
'word': word,
109+
'start': start,
110+
'end': end,
111+
'label': label,
112+
})
113+
114+
text_sentiment['entities'] = entities
115+
count += 1
116+
117+
output_file_path = os.path.join(output_dir_path, file_name)
118+
with open(output_file_path, 'w') as json_file:
119+
json.dump(data, json_file)
120+
121+
print("File: {}, Count: {}".format(file_name, count))
122+
123+
if __name__ == '__main__':
124+
parser = create_arg_parser()
125+
args = parser.parse_args()
126+
main(args.input_dir_path, args.output_dir_path, args.model_name_or_dir)

xlm-roberta-sh-ner/sh_ner_dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def create_sh_ner_dataset(train_json_path, val_json_path, tokenizer):
7676
return sh_ner_train_dataset, sh_ner_val_dataset
7777

7878
class SingaporeHansardNerDataset(Dataset):
79-
def __init__(self, sentences, sentences_entities, tokenizer):
79+
def __init__(self, sentences, sentences_entities, tokenizer, trainer=True):
8080
self.sentences = sentences
8181
self.sentences_entities = sentences_entities
8282

@@ -117,8 +117,10 @@ def __init__(self, sentences, sentences_entities, tokenizer):
117117
else:
118118
sentence_labels[i] = 'MASK'
119119

120-
del sentence_encodings['special_tokens_mask']
121-
del sentence_encodings['offset_mapping']
120+
if trainer:
121+
del sentence_encodings['special_tokens_mask']
122+
del sentence_encodings['offset_mapping']
123+
122124
sentence_labels = [self.label2id[label] for label in sentence_labels]
123125

124126
encodings.append(sentence_encodings)

0 commit comments

Comments
 (0)