Skip to content

Commit 5d5d445

Browse files
committed
Add benchmark for oov handling
1 parent 4a35855 commit 5d5d445

4 files changed

Lines changed: 1740 additions & 0 deletions

File tree

benchmark/oov/handle_oov.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
import re
2+
import os
3+
import sys
4+
import traceback
5+
from collections import Counter
6+
7+
class Model:
8+
@staticmethod
9+
def from_name(name, kiwi_model_path=None, bareun_api_key=None):
10+
if name == 'kiwi': return KiwiModel(kiwi_model_path)
11+
if name == 'kiwi-largest': return KiwiModel(kiwi_model_path, 'largest')
12+
if name == 'kiwi-cong': return KiwiModel(kiwi_model_path, 'cong')
13+
if name == 'kiwi-cong-global': return KiwiModel(kiwi_model_path, 'cong-global')
14+
if name == 'komoran': return KomoranModel()
15+
if name == 'kkma': return KkmaModel()
16+
if name == 'hannanum': return HannanumModel()
17+
if name == 'mecab': return MecabModel()
18+
if name == 'okt': return OktModel()
19+
if name == 'khaiii': return KhaiiiModel()
20+
if name == 'bareun': return BareunModel(bareun_api_key)
21+
raise ValueError(f'Unknown model name: {name}')
22+
23+
def _convert(self, morph):
24+
return morph
25+
26+
def _tokenize(self, text):
27+
raise NotImplementedError()
28+
29+
def _is_noun(self, tag):
30+
return tag.startswith('NN')
31+
32+
def tokenize(self, text):
33+
return list(map(self._convert, self._tokenize(text)))
34+
35+
def nouns(self, text):
36+
return [form for form, tag in self.tokenize(text) if self._is_noun(tag)]
37+
38+
class KiwiModel(Model):
39+
40+
def __init__(self, model_path=None, model_type='none', **kwargs):
41+
import kiwipiepy
42+
from kiwipiepy import Kiwi
43+
print("Initialize kiwipiepy ({})".format(kiwipiepy.__version__), file=sys.stderr)
44+
self._mdl = Kiwi(model_path=model_path, model_type=model_type, **kwargs)
45+
self.oov_handling = None
46+
self.config = None
47+
48+
def _convert(self, morph):
49+
return morph.form, morph.tag
50+
51+
def _tokenize(self, text):
52+
return self._mdl.tokenize(text, oov_handling=self.oov_handling, override_config=self.config)
53+
54+
class KomoranModel(Model):
55+
def __init__(self):
56+
import konlpy
57+
from konlpy import tag
58+
print("Initialize Komoran from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
59+
self._mdl = tag.Komoran()
60+
61+
def _tokenize(self, text):
62+
try:
63+
return self._mdl.pos(text)
64+
except:
65+
return []
66+
67+
class KkmaModel(Model):
68+
def __init__(self):
69+
import konlpy
70+
from konlpy import tag
71+
print("Initialize Kkma from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
72+
self._mdl = tag.Kkma()
73+
74+
def _tokenize(self, text):
75+
try:
76+
return self._mdl.pos(text)
77+
except:
78+
return []
79+
80+
class MecabModel(Model):
81+
def __init__(self):
82+
import konlpy
83+
from konlpy import tag
84+
print("Initialize Mecab from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
85+
self._mdl = tag.Mecab()
86+
87+
def _tokenize(self, text):
88+
try:
89+
return self._mdl.pos(text, split_inflect=True)
90+
except TypeError:
91+
return self._mdl.pos(text)
92+
93+
class HannanumModel(Model):
94+
95+
def __init__(self):
96+
import konlpy
97+
from konlpy import tag
98+
print("Initialize Hannanum from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
99+
self._mdl = tag.Hannanum()
100+
101+
def _convert(self, morph):
102+
if morph[1] == 'P':
103+
return morph[0], 'VV'
104+
return morph[0], morph[1]
105+
106+
def _is_noun(self, tag):
107+
return tag == 'N'
108+
109+
def _tokenize(self, text):
110+
return self._mdl.pos(text)
111+
112+
class OktModel(Model):
113+
114+
def __init__(self):
115+
import konlpy
116+
from konlpy import tag
117+
print("Initialize Okt from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
118+
self._mdl = tag.Okt()
119+
120+
def _convert(self, morph):
121+
if morph[1] == 'Verb':
122+
return morph[0][:-1], 'VV'
123+
return morph[0], morph[1]
124+
125+
def _is_noun(self, tag):
126+
return tag in ('Noun', 'Foreign')
127+
128+
def _tokenize(self, text):
129+
return self._mdl.pos(text, stem=True)
130+
131+
class KhaiiiModel(Model):
132+
def __init__(self):
133+
from khaiii import KhaiiiApi
134+
self._mdl = KhaiiiApi()
135+
print("Initialize khaiii ({})".format(self._mdl.version()), file=sys.stderr)
136+
137+
def _tokenize(self, text):
138+
return [(morph.lex, morph.tag) for word in self._mdl.analyze(text) for morph in word.morphs]
139+
140+
class BareunModel(Model):
141+
def __init__(self, api_key, host='localhost', port=5656) -> None:
142+
import bareunpy as brn
143+
self._mdl = brn.Tagger(api_key, host, port)
144+
print(f"Initialize Bareun from bareunpy (version={brn.version}, bareun_version={brn.bareun_version})", file=sys.stderr)
145+
146+
def _tokenize(self, text):
147+
return self._mdl.tag(text).pos()
148+
149+
def load_dataset(path):
150+
tag_pattern = re.compile(r'<n(?: e="([^"]+)")?>(.*?)</n>')
151+
152+
ret = []
153+
for line in open(path, encoding='utf-8'):
154+
golds = Counter()
155+
gold_types = {}
156+
for m in tag_pattern.finditer(line):
157+
form = m.group(2)
158+
etype = m.group(1)
159+
gold_types[form] = etype
160+
golds[form] += 1
161+
162+
exam = tag_pattern.sub(r'\2', line).rstrip()
163+
ret.append((golds, gold_types, exam))
164+
return ret
165+
166+
def evaluate(dataset, model, score_by_type=False, result_output=None):
167+
gold_per_type = Counter()
168+
pred_per_type = Counter()
169+
correct_per_type = Counter()
170+
gold_chr_per_type = Counter()
171+
pred_chr_per_type = Counter()
172+
correct_chr_per_type = Counter()
173+
174+
results = []
175+
for golds, gold_types, exam in dataset:
176+
result = model.nouns(exam)
177+
if result_output is not None:
178+
tokens = model.tokenize(exam)
179+
print(' '.join(f'{form}/{t}' for form, t in gold_types.items() if t),
180+
exam,
181+
' '.join(f'{form}/{tag}' for form, tag in tokens), sep='\t', file=result_output, flush=True)
182+
preds = Counter(r.replace(' ', '') for r in result)
183+
184+
for form, count in golds.items():
185+
gold_per_type[gold_types.get(form)] += count
186+
gold_chr_per_type[gold_types.get(form)] += len(form) * count
187+
188+
for form, count in preds.items():
189+
pred_per_type[gold_types.get(form)] += count
190+
pred_chr_per_type[gold_types.get(form)] += len(form) * count
191+
192+
for form, count in (golds & preds).items():
193+
correct_per_type[gold_types.get(form)] += count
194+
correct_chr_per_type[gold_types.get(form)] += len(form) * count
195+
196+
results.append(result)
197+
198+
all_types = sorted(filter(None, gold_per_type))
199+
200+
scores = {}
201+
scores['labeled_recall'] = sum(correct_per_type[t] for t in all_types) / max(sum(gold_per_type[t] for t in all_types), 1)
202+
if score_by_type:
203+
scores.update({f'labeled_recall {t}': correct_per_type[t] / max(gold_per_type[t], 1) for t in all_types})
204+
205+
scores['precision'] = p = sum(correct_per_type.values()) / max(sum(pred_per_type.values()), 1)
206+
scores['recall'] = r = sum(correct_per_type.values()) / max(sum(gold_per_type.values()), 1)
207+
scores['f1'] = 2 * p * r / max(p + r, 1)
208+
if score_by_type:
209+
for t in all_types:
210+
scores[f'precision {t}'] = p = correct_per_type[t] / max(pred_per_type[t], 1)
211+
scores[f'recall {t}'] = r = correct_per_type[t] / max(gold_per_type[t], 1)
212+
scores[f'f1 {t}'] = 2 * p * r / max(p + r, 1)
213+
214+
scores['chr_precision'] = p = sum(correct_chr_per_type.values()) / max(sum(pred_chr_per_type.values()), 1)
215+
scores['chr_recall'] = r = sum(correct_chr_per_type.values()) / max(sum(gold_chr_per_type.values()), 1)
216+
scores['chr_f1'] = 2 * p * r / max(p + r, 1)
217+
if score_by_type:
218+
for t in all_types:
219+
scores[f'chr_precision {t}'] = p = correct_chr_per_type[t] / max(pred_chr_per_type[t], 1)
220+
scores[f'chr_recall {t}'] = r = correct_chr_per_type[t] / max(gold_chr_per_type[t], 1)
221+
scores[f'chr_f1 {t}'] = 2 * p * r / max(p + r, 1)
222+
return scores, results
223+
224+
def test_kiwi_oov_handling(args):
225+
from kiwipiepy import KiwiConfig
226+
model = KiwiModel(model_path=args.kiwi_model_path, load_default_dict=args.test_kiwi_with_dictionary, load_multi_dict=args.test_kiwi_with_dictionary)
227+
228+
settings = [
229+
*[(f'rule (bias={bias})', {'oov_handling': 'rule', 'config': KiwiConfig(oov_rule_bias=bias)}) for bias in range(-3, 4)],
230+
*[(f'c (bias={bias})', {'oov_handling': 'chr', 'config': KiwiConfig(oov_chr_bias=bias)}) for bias in range(-3, 4)],
231+
*[(f'cf (bias={bias}, global_weight=35)', {'oov_handling': 'chr_freq', 'config': KiwiConfig(oov_chr_bias=bias, oov_global_weight=35)}) for bias in range(-3, 4)],
232+
*[(f'cf (bias=-3, global_weight={w})', {'oov_handling': 'chr_freq', 'config': KiwiConfig(oov_chr_bias=-3, oov_global_weight=w)}) for w in range(5, 105, 5)],
233+
]
234+
235+
if args.result_output:
236+
result_outputs = [open(f'{os.path.splitext(args.result_output)[0]}_{s[0]}{os.path.splitext(args.result_output)[1]}', 'w', encoding='utf-8') for s in settings]
237+
238+
print('', '', *[s[0] for s in settings], sep='\t')
239+
for dataset in args.datasets:
240+
ds = load_dataset(dataset)
241+
all_scores = []
242+
all_results = [ds]
243+
for i, (name, params) in enumerate(settings):
244+
model.oov_handling = params['oov_handling']
245+
model.config = params.get('config')
246+
score, results = evaluate(ds, model, score_by_type=args.score_by_type, result_output=result_outputs[i] if args.result_output else None)
247+
all_scores.append(score)
248+
all_results.append(results)
249+
250+
for key in score:
251+
print(os.path.basename(dataset), f'({key})', *((f'{s[key]:.4f}' if s[key] is not None else '-') for s in all_scores), sep='\t')
252+
253+
if args.result_output:
254+
for f in result_outputs:
255+
f.close()
256+
257+
def main(args):
258+
if args.test_kiwi_oov_handling:
259+
return test_kiwi_oov_handling(args)
260+
261+
models = [Model.from_name(n,
262+
kiwi_model_path=args.kiwi_model_path,
263+
bareun_api_key=args.bareun_api_key) for n in args.target]
264+
265+
if args.result_output:
266+
if len(models) == 1:
267+
result_outputs = [open(args.result_output, 'w', encoding='utf-8')]
268+
else:
269+
result_outputs = [open(f'{os.path.splitext(args.result_output)[0]}_{n}{os.path.splitext(args.result_output)[1]}', 'w', encoding='utf-8') for n in args.target]
270+
271+
print('', '', *args.target, sep='\t')
272+
for dataset in args.datasets:
273+
ds = load_dataset(dataset)
274+
all_scores = []
275+
all_results = [ds]
276+
for i, model in enumerate(models):
277+
score, results = evaluate(ds, model, score_by_type=args.score_by_type, result_output=result_outputs[i] if args.result_output else None)
278+
all_scores.append(score)
279+
all_results.append(results)
280+
281+
for key in score:
282+
print(os.path.basename(dataset), f'({key})', *((f'{s[key]:.4f}' if s[key] is not None else '-') for s in all_scores), sep='\t')
283+
284+
if args.result_output:
285+
for f in result_outputs:
286+
f.close()
287+
288+
if __name__ == '__main__':
289+
import argparse
290+
291+
parser = argparse.ArgumentParser()
292+
parser.add_argument('datasets', nargs='+')
293+
parser.add_argument('--target', default=['kiwi'], nargs='+', choices=['kiwi', 'kiwi-largest', 'komoran', 'mecab', 'kkma', 'hannanum', 'okt', 'khaiii', 'bareun'])
294+
parser.add_argument('--kiwi-model-path')
295+
parser.add_argument('--bareun-api-key')
296+
parser.add_argument('--test-kiwi-oov-handling', action='store_true')
297+
parser.add_argument('--test-kiwi-with-dictionary', action='store_true')
298+
parser.add_argument('--score-by-type', action='store_true')
299+
parser.add_argument('--result-output')
300+
main(parser.parse_args())

0 commit comments

Comments
 (0)