-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdialmonkey_extract.py
More file actions
165 lines (144 loc) · 6.91 KB
/
dialmonkey_extract.py
File metadata and controls
165 lines (144 loc) · 6.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from dialmonkey.nlu.public_transport_cs import PublicTransportCSNLU, string_func
from dialmonkey.da import DA
from ufal.morphodita import Tokenizer, Forms, TokenRanges
import copy
import re
class TinkeredPTCSNLU(PublicTransportCSNLU):
    """Variant of ``PublicTransportCSNLU`` that only abstracts STOP and CITY forms.

    Besides the dialogue act, :meth:`tinkered_parse` stores the abstracted and
    annotated utterances, the surface-form -> value mapping and the abstraction
    lengths in the dialogue dict it is given.
    """

    # Punctuation characters that may be glued to the last token of a matched form.
    _TRAILING_PUNCTUATION = ('.', ',', '!', '?', ':', ';')

    def __init__(self, utt2da="utt2da.tsv"):
        """Create the NLU.

        :param utt2da: value passed to the parent as the ``utt2da`` config entry
            (defaults to ``"utt2da.tsv"``, the previously hard-coded path)
        """
        super().__init__(config={"utt2da": utt2da})
        self.tokenizer = Tokenizer.newCzechTokenizer()
        # Parenthesised markers such as "(breath)", plus any whitespace before them.
        self.nonspeech_pattern = re.compile(r'\s*\([^()]*\)')
        # Every distinct marker removed so far; accumulates across calls.
        self.nonspeech_matched = set()

    def tokenize(self, text):
        """Split ``text`` into token strings using the MorphoDiTa Czech tokenizer.

        :param text: raw text
        :return: list of token substrings of ``text``, across all sentences
        """
        self.tokenizer.setText(text)
        forms = Forms()
        tokens = TokenRanges()
        result = []
        while self.tokenizer.nextSentence(forms, tokens):
            result.extend(text[token.start:token.start + token.length] for token in tokens)
        return result

    def tinkered_abstract_utterance(self, utterance):
        """Abstract STOP and CITY mentions in ``utterance``.

        The utterance is scanned left to right. At each position the longest token
        span whose form is in ``self.cldb.form2value2cl`` is looked up. When the
        span's last token ends with one punctuation character, that character is
        ignored for matching. Only values whose category is STOP or CITY are
        abstracted.

        :param utterance: normalized token list of strings (presumably a
            ``string_func.TokenList`` as built by ``tinkered_parse``; it must
            support ``copy.deepcopy``, ``len``, indexing/slicing and ``replace``)
        :return: tuple ``(abs_utts, ann_utts, category_labels,
            norm_abs_utt_lengths, form_value_mapping)``:

            - ``abs_utts``: matched spans replaced by ``"CATEGORY=value"``
            - ``ann_utts``: matched spans replaced by ``"CATEGORY=surface form"``
            - ``category_labels``: set of abstracted (upper-cased) categories
            - ``norm_abs_utt_lengths``: one entry per abstracted token, giving the
              number of original tokens it covers
            - ``form_value_mapping``: surface form (space-joined) -> value
        """
        abs_utts = copy.deepcopy(utterance)
        ann_utts = copy.deepcopy(utterance)
        form_value_mapping = {}
        category_labels = set()
        abs_utt_lengths = [1] * len(abs_utts)
        start = 0
        while start < len(utterance):
            end = len(utterance)
            while end > start:
                orig = tuple(utterance[start:end])
                form = orig
                last_token = utterance[end - 1]
                # Punctuation is not split off the tokens, so also try the form
                # without it. endswith() does not fail on an empty token (unlike [-1]).
                if last_token.endswith(self._TRAILING_PUNCTUATION):
                    form = orig[:-1] + (last_token[:-1],)
                if form in self.cldb.form2value2cl:
                    value2cl = self.cldb.form2value2cl[form]
                    # Use the 1st matching STOP/CITY value (XXX there's no good way of disambiguating).
                    for value in value2cl:
                        categories = value2cl[value]
                        if not categories:  # no categories -- shouldn't happen
                            continue
                        if len(categories) > 1:  # ambiguous categories -- try disambiguating
                            category = self.disambiguate_category(utterance, start, end, categories)
                        else:
                            category = categories[0]
                        label = category.upper()
                        if label == "STOP" or label == "CITY":
                            form_value_mapping[' '.join(form)] = value
                            # The whole original span is replaced, so any trailing
                            # punctuation on it is dropped from both outputs.
                            abs_utts = abs_utts.replace(list(orig), [label + '=' + value])
                            ann_utts = ann_utts.replace(list(orig), [label + '=' + ' '.join(form)])
                            abs_utt_lengths[start] = len(form)
                            category_labels.add(label)
                            break
                    # Skip all sub-spans of this form, even if nothing was abstracted.
                    start = end
                    break
                end -= 1
            else:
                start += 1

        # Walk the per-position lengths, jumping over each abstracted span, so that
        # there is exactly one entry per token of the abstracted utterance.
        norm_abs_utt_lengths = []
        i = 0
        while i < len(abs_utt_lengths):
            length = abs_utt_lengths[i]
            norm_abs_utt_lengths.append(length)
            i += length
        return abs_utts, ann_utts, category_labels, norm_abs_utt_lengths, form_value_mapping

    def _strip_nonspeech(self, utterance):
        """Remove parenthesised non-speech markers from ``utterance``.

        Each removed marker, stripped of surrounding whitespace, is added to
        ``self.nonspeech_matched``.

        :param utterance: raw utterance string
        :return: ``utterance`` with every match of ``self.nonspeech_pattern`` deleted
        """
        def _record_and_drop(match):
            self.nonspeech_matched.add(match.group(0).strip())
            return ''

        return self.nonspeech_pattern.sub(_record_and_drop, utterance)

    def tinkered_parse(self, dial: dict) -> dict:
        """Parse ``dial['user']`` into a dialogue act and store the results in ``dial``.

        If the raw utterance is a key of ``self.utt2da``, that DA is used directly.
        Otherwise the utterance is processed in these steps:

        1. Non-speech markers are removed.
        2. The first letter is capitalized.
        3. The utterance is normalized.
        4. STOP/CITY forms are abstracted.
        5. The parent class's ``parse_*`` helpers are applied.

        An utterance that becomes empty after marker removal no longer raises
        ``IndexError``. NOTE(review): how the parent's parse helpers treat an
        empty utterance is not visible here; confirm.

        :param dial: dict with the user utterance string under ``'user'``;
            it is updated in place
        :return: the same ``dial`` with ``'nlu'``, ``'abutterance'``,
            ``'anutterance'``, ``'fv_mapping'`` and ``'abutterance_lenghts'`` set.
            On the non-lookup path ``'user'`` is replaced by the cleaned utterance.
        """
        utterance = dial['user']
        dict_da = self.utt2da.get(str(utterance), None)
        if dict_da:
            dial.update({
                'nlu': dict_da,
                'abutterance': utterance,
                'anutterance': utterance,
                'fv_mapping': {},
                'abutterance_lenghts': [1] * len(self.tokenize(utterance)),
            })
            return dial

        res_da = DA()
        # Remove nonspeech events (remembering them for later checks) and
        # capitalize the first letter; [:1] keeps this safe for an empty string.
        utterance = self._strip_nonspeech(utterance).strip()
        utterance = utterance[:1].upper() + utterance[1:]
        dial['user'] = utterance

        utterance = self.preprocessing.normalize(string_func.TokenList(utterance))
        abutterance, anutterance, category_labels, abutterance_lenghts, form_value_mapping = \
            self.tinkered_abstract_utterance(utterance)
        dial["fv_mapping"] = form_value_mapping
        dial["anutterance"] = list(anutterance)

        self.parse_non_speech_events(utterance, res_da)
        abutterance = self.handle_false_abstractions(abutterance)
        # Always run these parsers, whether or not the abstraction produced the labels.
        category_labels.add('CITY')
        category_labels.add('VEHICLE')
        category_labels.add('NUMBER')

        if len(res_da) == 0:
            if 'STOP' in category_labels:
                self.parse_stop(abutterance, res_da)
            if 'CITY' in category_labels:
                self.parse_city(abutterance, res_da)
            if 'NUMBER' in category_labels:
                # Takes no res_da; presumably rewrites abutterance in place -- verify in parent.
                self.parse_number(abutterance)
                if any(word.startswith("TIME") for word in abutterance):
                    category_labels.add('TIME')
            if 'TIME' in category_labels:
                self.parse_time(abutterance, res_da)
            if 'DATE_REL' in category_labels:
                self.parse_date_rel(abutterance, res_da)
            if 'AMPM' in category_labels:
                self.parse_ampm(abutterance, res_da)
            if 'VEHICLE' in category_labels:
                self.parse_vehicle(abutterance, res_da)
            if 'TASK' in category_labels:
                self.parse_task(abutterance, res_da)
            if 'TRAIN_NAME' in category_labels:
                self.parse_train_name(abutterance, res_da)
            self.parse_meta(utterance, abutterance_lenghts, res_da)

        res_da.merge_duplicate_dais()
        dial['nlu'] = res_da
        dial['abutterance'] = list(abutterance)
        dial['abutterance_lenghts'] = abutterance_lenghts
        return dial
if __name__ == "__main__":
    # Demo: parse a handful of sample Czech utterances and show the results.
    nlu = TinkeredPTCSNLU()
    sample_utterances = (
        "chci jet z anděla na malostranskou. nebo vlastně z anděla",
        "z anděla na malostranské náměstí",
        "chci jet z Prahy přes Palmovku",
        "super, krásný, děkuju (breath) (noise)",
        "jsem zde",
        "v kolik přijedu na zastávku Praha, Zličín",
    )
    for sample in sample_utterances:
        print(nlu.tinkered_parse({"user": sample}))
    # Non-speech markers collected across all of the calls above.
    print(nlu.nonspeech_matched)