|
| 1 | +import re |
| 2 | +import os |
| 3 | +import sys |
| 4 | +import traceback |
| 5 | +from collections import Counter |
| 6 | + |
class Model:
    """Common interface over the morphological analyzers compared by this script.

    Subclasses provide ``_tokenize`` and may override ``_convert`` and
    ``_is_noun`` to adapt their analyzer's output.
    """

    @staticmethod
    def from_name(name, kiwi_model_path=None, bareun_api_key=None):
        """Build the model registered under ``name``.

        ``kiwi_model_path`` is passed to the Kiwi variants and
        ``bareun_api_key`` to the Bareun model; the other models take no
        arguments.

        Raises:
            ValueError: if ``name`` matches none of the known model names.
        """
        # Each entry defers construction to a lambda, so a model class is only
        # looked up and instantiated when its name is actually requested.
        registry = (
            ('kiwi', lambda: KiwiModel(kiwi_model_path)),
            ('kiwi-largest', lambda: KiwiModel(kiwi_model_path, 'largest')),
            ('kiwi-cong', lambda: KiwiModel(kiwi_model_path, 'cong')),
            ('kiwi-cong-global', lambda: KiwiModel(kiwi_model_path, 'cong-global')),
            ('komoran', lambda: KomoranModel()),
            ('kkma', lambda: KkmaModel()),
            ('hannanum', lambda: HannanumModel()),
            ('mecab', lambda: MecabModel()),
            ('okt', lambda: OktModel()),
            ('khaiii', lambda: KhaiiiModel()),
            ('bareun', lambda: BareunModel(bareun_api_key)),
        )
        for known_name, build in registry:
            if name == known_name:
                return build()
        raise ValueError(f'Unknown model name: {name}')

    def _convert(self, morph):
        """Map one raw token to the value ``tokenize`` returns; identity by default."""
        return morph

    def _tokenize(self, text):
        """Run the underlying analyzer on ``text``; subclasses must implement this."""
        raise NotImplementedError()

    def _is_noun(self, tag):
        """Tell whether ``tag`` denotes a noun; by default any tag starting with ``NN``."""
        return tag.startswith('NN')

    def tokenize(self, text):
        """Return the list of converted tokens for ``text``."""
        return [self._convert(raw) for raw in self._tokenize(text)]

    def nouns(self, text):
        """Return the forms of the tokens in ``text`` whose tag is a noun tag."""
        noun_forms = []
        for form, tag in self.tokenize(text):
            if self._is_noun(tag):
                noun_forms.append(form)
        return noun_forms
| 37 | + |
class KiwiModel(Model):
    """Model backed by kiwipiepy's ``Kiwi``.

    ``oov_handling`` and ``config`` are plain attributes (both ``None`` after
    construction) that callers may reassign; their current values are passed
    to the analyzer on every tokenization.
    """

    def __init__(self, model_path=None, model_type='none', **kwargs):
        """Create the ``Kiwi`` instance; extra keyword arguments are forwarded to it."""
        # Imported here so kiwipiepy is only required when this model is built.
        import kiwipiepy

        banner = "Initialize kiwipiepy ({})".format(kiwipiepy.__version__)
        print(banner, file=sys.stderr)

        self.oov_handling = None
        self.config = None
        self._mdl = kiwipiepy.Kiwi(model_path=model_path, model_type=model_type, **kwargs)

    def _convert(self, morph):
        """Reduce a token object to a ``(form, tag)`` pair."""
        return (morph.form, morph.tag)

    def _tokenize(self, text):
        """Tokenize ``text`` with the current ``oov_handling`` and ``config`` settings."""
        options = {
            'oov_handling': self.oov_handling,
            'override_config': self.config,
        }
        return self._mdl.tokenize(text, **options)
| 53 | + |
class KomoranModel(Model):
    """Model backed by konlpy's ``Komoran`` tagger."""

    def __init__(self):
        """Create the tagger and announce the konlpy version on stderr."""
        # Imported here so konlpy is only required when this model is built.
        import konlpy
        from konlpy import tag
        print("Initialize Komoran from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
        self._mdl = tag.Komoran()

    def _tokenize(self, text):
        """Return ``self._mdl.pos(text)``, or ``[]`` if the tagger raises.

        A failing input is treated as "no tokens" so that a single bad
        sentence does not abort the run; the failure is reported on stderr.
        """
        try:
            return self._mdl.pos(text)
        except Exception as e:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate.
            print(f"Komoran failed to analyze {text!r}: {e!r}", file=sys.stderr)
            return []
| 66 | + |
class KkmaModel(Model):
    """Model backed by konlpy's ``Kkma`` tagger."""

    def __init__(self):
        """Create the tagger and announce the konlpy version on stderr."""
        # Imported here so konlpy is only required when this model is built.
        import konlpy
        from konlpy import tag
        print("Initialize Kkma from konlpy ({})".format(konlpy.__version__), file=sys.stderr)
        self._mdl = tag.Kkma()

    def _tokenize(self, text):
        """Return ``self._mdl.pos(text)``, or ``[]`` if the tagger raises.

        A failing input is treated as "no tokens" so that a single bad
        sentence does not abort the run; the failure is reported on stderr.
        """
        try:
            return self._mdl.pos(text)
        except Exception as e:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate.
            print(f"Kkma failed to analyze {text!r}: {e!r}", file=sys.stderr)
            return []
| 79 | + |
class MecabModel(Model):
    """Model backed by konlpy's ``Mecab`` tagger."""

    def __init__(self):
        """Create the tagger and announce the konlpy version on stderr."""
        # Imported here so konlpy is only required when this model is built.
        import konlpy
        from konlpy import tag

        banner = "Initialize Mecab from konlpy ({})".format(konlpy.__version__)
        print(banner, file=sys.stderr)
        self._mdl = tag.Mecab()

    def _tokenize(self, text):
        """Tag ``text``, asking for ``split_inflect=True`` when ``pos`` accepts it."""
        tagger = self._mdl
        try:
            tokens = tagger.pos(text, split_inflect=True)
        except TypeError:
            # Presumably a Mecab wrapper whose ``pos`` lacks the
            # ``split_inflect`` keyword -- TODO confirm; retry without it.
            tokens = tagger.pos(text)
        return tokens
| 92 | + |
class HannanumModel(Model):
    """Model backed by konlpy's ``Hannanum`` tagger."""

    def __init__(self):
        """Create the tagger and announce the konlpy version on stderr."""
        # Imported here so konlpy is only required when this model is built.
        import konlpy
        from konlpy import tag

        banner = "Initialize Hannanum from konlpy ({})".format(konlpy.__version__)
        print(banner, file=sys.stderr)
        self._mdl = tag.Hannanum()

    def _tokenize(self, text):
        """Return the tagger's ``pos`` output for ``text``."""
        return self._mdl.pos(text)

    def _convert(self, morph):
        """Return ``(form, tag)``, rewriting the tag ``'P'`` to ``'VV'``."""
        if morph[1] != 'P':
            return morph[0], morph[1]
        return morph[0], 'VV'

    def _is_noun(self, tag):
        """A token counts as a noun only when its tag is exactly ``'N'``."""
        return tag == 'N'
| 111 | + |
class OktModel(Model):
    """Model backed by konlpy's ``Okt`` tagger, run with ``stem=True``."""

    def __init__(self):
        """Create the tagger and announce the konlpy version on stderr."""
        # Imported here so konlpy is only required when this model is built.
        import konlpy
        from konlpy import tag

        banner = "Initialize Okt from konlpy ({})".format(konlpy.__version__)
        print(banner, file=sys.stderr)
        self._mdl = tag.Okt()

    def _tokenize(self, text):
        """Return the tagger's ``pos`` output for ``text`` with stemming enabled."""
        return self._mdl.pos(text, stem=True)

    def _convert(self, morph):
        """Return ``(form, tag)``; ``'Verb'`` tokens lose their last character and become ``'VV'``."""
        if morph[1] != 'Verb':
            return morph[0], morph[1]
        # Presumably strips the dictionary-form ending of the stemmed verb
        # -- TODO confirm against Okt's stem=True output.
        return morph[0][:-1], 'VV'

    def _is_noun(self, tag):
        """A token counts as a noun when tagged ``'Noun'`` or ``'Foreign'``."""
        return tag == 'Noun' or tag == 'Foreign'
| 130 | + |
class KhaiiiModel(Model):
    """Model backed by ``khaiii.KhaiiiApi``."""

    def __init__(self):
        """Create the API object and announce its version on stderr."""
        # Imported here so khaiii is only required when this model is built.
        from khaiii import KhaiiiApi

        self._mdl = KhaiiiApi()
        banner = "Initialize khaiii ({})".format(self._mdl.version())
        print(banner, file=sys.stderr)

    def _tokenize(self, text):
        """Flatten the analysis of ``text`` into a list of ``(lex, tag)`` pairs."""
        pairs = []
        for word in self._mdl.analyze(text):
            for morph in word.morphs:
                pairs.append((morph.lex, morph.tag))
        return pairs
| 139 | + |
class BareunModel(Model):
    """Model backed by ``bareunpy.Tagger``."""

    def __init__(self, api_key, host='localhost', port=5656) -> None:
        """Create a tagger from ``api_key``, ``host`` and ``port`` and announce the versions on stderr."""
        # Imported here so bareunpy is only required when this model is built.
        import bareunpy as brn

        self._mdl = brn.Tagger(api_key, host, port)
        print(
            f"Initialize Bareun from bareunpy (version={brn.version}, bareun_version={brn.bareun_version})",
            file=sys.stderr,
        )

    def _tokenize(self, text):
        """Tag ``text`` and return the result's ``pos()`` output."""
        tagged = self._mdl.tag(text)
        return tagged.pos()
| 148 | + |
def load_dataset(path):
    """Read a noun-annotated corpus with one example per line.

    Nouns are marked inline as ``<n>form</n>`` or ``<n e="TYPE">form</n>``.

    Args:
        path: path of a UTF-8 text file.

    Returns:
        A list with one ``(golds, gold_types, exam)`` tuple per line:
        ``golds`` is a ``Counter`` of the marked forms, ``gold_types`` maps
        each marked form to its ``e`` attribute (``None`` when absent; a
        later occurrence of the same form overwrites an earlier one), and
        ``exam`` is the line with the markup removed and trailing whitespace
        stripped.
    """
    tag_pattern = re.compile(r'<n(?: e="([^"]+)")?>(.*?)</n>')

    ret = []
    # The context manager closes the file even if parsing raises part-way.
    with open(path, encoding='utf-8') as f:
        for line in f:
            golds = Counter()
            gold_types = {}
            for m in tag_pattern.finditer(line):
                etype, form = m.group(1), m.group(2)
                gold_types[form] = etype
                golds[form] += 1

            exam = tag_pattern.sub(r'\2', line).rstrip()
            ret.append((golds, gold_types, exam))
    return ret
| 165 | + |
def _safe_ratio(numerator, denominator):
    """Return ``numerator / max(denominator, 1)`` so a zero count yields 0 instead of raising."""
    return numerator / max(denominator, 1)


def _f1(p, r):
    """Return the harmonic mean of precision ``p`` and recall ``r``; 0.0 when both are zero."""
    total = p + r
    # Only the zero case needs guarding. Clamping the denominator to 1 (as an
    # integer count would be) under-reports F1 whenever p + r < 1.
    return 2 * p * r / total if total > 0 else 0.0


def _add_prf(scores, prefix, correct, pred, gold, types):
    """Store overall precision/recall/F1 under ``prefix``, then the same per type in ``types``."""
    p = _safe_ratio(sum(correct.values()), sum(pred.values()))
    r = _safe_ratio(sum(correct.values()), sum(gold.values()))
    scores[f'{prefix}precision'] = p
    scores[f'{prefix}recall'] = r
    scores[f'{prefix}f1'] = _f1(p, r)
    for t in types:
        p = _safe_ratio(correct[t], pred[t])
        r = _safe_ratio(correct[t], gold[t])
        scores[f'{prefix}precision {t}'] = p
        scores[f'{prefix}recall {t}'] = r
        scores[f'{prefix}f1 {t}'] = _f1(p, r)


def evaluate(dataset, model, score_by_type=False, result_output=None):
    """Score ``model.nouns`` against a gold dataset.

    Args:
        dataset: iterable of ``(golds, gold_types, exam)`` tuples, where
            ``golds`` is a ``Counter`` of gold noun forms, ``gold_types`` maps
            forms to a type label (or ``None``) and ``exam`` is the text to
            analyze.
        model: object providing ``nouns(text)``; ``tokenize(text)`` is also
            called when ``result_output`` is given.
        score_by_type: when true, also emit every metric per type label.
        result_output: optional writable text stream; for each example one
            tab-separated line (typed gold forms, text, tokens) is written.

    Returns:
        ``(scores, results)``: ``scores`` is a dict of metric name to value
        (count-based and character-weighted precision/recall/F1, plus
        ``labeled_recall`` restricted to forms with a type label), and
        ``results`` is the list of ``model.nouns`` outputs, one per example.
    """
    gold_per_type = Counter()
    pred_per_type = Counter()
    correct_per_type = Counter()
    gold_chr_per_type = Counter()
    pred_chr_per_type = Counter()
    correct_chr_per_type = Counter()

    results = []
    for golds, gold_types, exam in dataset:
        result = model.nouns(exam)
        if result_output is not None:
            tokens = model.tokenize(exam)
            print(' '.join(f'{form}/{t}' for form, t in gold_types.items() if t),
                  exam,
                  ' '.join(f'{form}/{tag}' for form, tag in tokens),
                  sep='\t', file=result_output, flush=True)
        # Spaces inside a predicted form are dropped before matching.
        preds = Counter(r.replace(' ', '') for r in result)

        # Every tally is keyed by the form's gold type; forms without a type
        # (including predictions absent from the gold set) fall under None.
        tallies = (
            (golds, gold_per_type, gold_chr_per_type),
            (preds, pred_per_type, pred_chr_per_type),
            (golds & preds, correct_per_type, correct_chr_per_type),
        )
        for forms, per_type, chr_per_type in tallies:
            for form, count in forms.items():
                etype = gold_types.get(form)
                per_type[etype] += count
                chr_per_type[etype] += len(form) * count

        results.append(result)

    all_types = sorted(filter(None, gold_per_type))
    typed = all_types if score_by_type else ()

    scores = {}
    scores['labeled_recall'] = _safe_ratio(sum(correct_per_type[t] for t in all_types),
                                           sum(gold_per_type[t] for t in all_types))
    for t in typed:
        scores[f'labeled_recall {t}'] = _safe_ratio(correct_per_type[t], gold_per_type[t])

    _add_prf(scores, '', correct_per_type, pred_per_type, gold_per_type, typed)
    _add_prf(scores, 'chr_', correct_chr_per_type, pred_chr_per_type, gold_chr_per_type, typed)
    return scores, results
| 223 | + |
def test_kiwi_oov_handling(args):
    """Evaluate one Kiwi model under a grid of OOV-handling settings.

    Reads ``args.kiwi_model_path``, ``args.test_kiwi_with_dictionary``,
    ``args.datasets``, ``args.score_by_type`` and ``args.result_output``.
    A tab-separated score table is printed to stdout; when
    ``args.result_output`` is set, each setting's per-example output goes to
    its own file named ``<stem>_<setting name><ext>``.
    """
    from kiwipiepy import KiwiConfig
    model = KiwiModel(model_path=args.kiwi_model_path,
                      load_default_dict=args.test_kiwi_with_dictionary,
                      load_multi_dict=args.test_kiwi_with_dictionary)

    settings = [
        *[(f'rule (bias={bias})', {'oov_handling': 'rule', 'config': KiwiConfig(oov_rule_bias=bias)}) for bias in range(-3, 4)],
        *[(f'c (bias={bias})', {'oov_handling': 'chr', 'config': KiwiConfig(oov_chr_bias=bias)}) for bias in range(-3, 4)],
        *[(f'cf (bias={bias}, global_weight=35)', {'oov_handling': 'chr_freq', 'config': KiwiConfig(oov_chr_bias=bias, oov_global_weight=35)}) for bias in range(-3, 4)],
        *[(f'cf (bias=-3, global_weight={w})', {'oov_handling': 'chr_freq', 'config': KiwiConfig(oov_chr_bias=-3, oov_global_weight=w)}) for w in range(5, 105, 5)],
    ]

    result_outputs = []
    try:
        if args.result_output:
            stem, ext = os.path.splitext(args.result_output)
            # Opened one at a time inside the try block, so files already
            # opened are closed even if a later open() or the evaluation fails.
            for name, _ in settings:
                result_outputs.append(open(f'{stem}_{name}{ext}', 'w', encoding='utf-8'))

        print('', '', *[s[0] for s in settings], sep='\t')
        for dataset in args.datasets:
            ds = load_dataset(dataset)
            all_scores = []
            for i, (name, params) in enumerate(settings):
                # The same model object is reused; only its settings change.
                model.oov_handling = params['oov_handling']
                model.config = params.get('config')
                score, _ = evaluate(ds, model, score_by_type=args.score_by_type,
                                    result_output=result_outputs[i] if args.result_output else None)
                all_scores.append(score)

            # Every setting yields the same metric keys; print one row per metric.
            for key in (all_scores[-1] if all_scores else ()):
                print(os.path.basename(dataset), f'({key})',
                      *((f'{s[key]:.4f}' if s[key] is not None else '-') for s in all_scores),
                      sep='\t')
    finally:
        for f in result_outputs:
            f.close()
| 256 | + |
def main(args):
    """Evaluate every model named in ``args.target`` on every dataset.

    When ``args.test_kiwi_oov_handling`` is set, the run is delegated to
    ``test_kiwi_oov_handling`` instead. Otherwise a tab-separated score table
    is printed to stdout. When ``args.result_output`` is set, per-example
    output is written to that path for a single model, or to
    ``<stem>_<model name><ext>`` per model when several are evaluated.
    """
    if args.test_kiwi_oov_handling:
        return test_kiwi_oov_handling(args)

    models = [Model.from_name(n,
                              kiwi_model_path=args.kiwi_model_path,
                              bareun_api_key=args.bareun_api_key) for n in args.target]

    result_outputs = []
    try:
        if args.result_output:
            # Opened one at a time inside the try block, so files already
            # opened are closed even if a later open() or the evaluation fails.
            if len(models) == 1:
                result_outputs.append(open(args.result_output, 'w', encoding='utf-8'))
            else:
                stem, ext = os.path.splitext(args.result_output)
                for n in args.target:
                    result_outputs.append(open(f'{stem}_{n}{ext}', 'w', encoding='utf-8'))

        print('', '', *args.target, sep='\t')
        for dataset in args.datasets:
            ds = load_dataset(dataset)
            all_scores = []
            for i, model in enumerate(models):
                score, _ = evaluate(ds, model, score_by_type=args.score_by_type,
                                    result_output=result_outputs[i] if args.result_output else None)
                all_scores.append(score)

            # Every model yields the same metric keys; print one row per metric.
            for key in (all_scores[-1] if all_scores else ()):
                print(os.path.basename(dataset), f'({key})',
                      *((f'{s[key]:.4f}' if s[key] is not None else '-') for s in all_scores),
                      sep='\t')
    finally:
        for f in result_outputs:
            f.close()
| 287 | + |
if __name__ == '__main__':
    import argparse

    # Command-line entry point: positional dataset paths plus options
    # selecting the models to evaluate and what to report.
    parser = argparse.ArgumentParser()
    parser.add_argument('datasets', nargs='+')
    # The choices list every name Model.from_name accepts, including the
    # 'kiwi-cong' and 'kiwi-cong-global' variants.
    parser.add_argument('--target', default=['kiwi'], nargs='+',
                        choices=['kiwi', 'kiwi-largest', 'kiwi-cong', 'kiwi-cong-global',
                                 'komoran', 'mecab', 'kkma', 'hannanum', 'okt', 'khaiii', 'bareun'])
    parser.add_argument('--kiwi-model-path')
    parser.add_argument('--bareun-api-key')
    parser.add_argument('--test-kiwi-oov-handling', action='store_true')
    parser.add_argument('--test-kiwi-with-dictionary', action='store_true')
    parser.add_argument('--score-by-type', action='store_true')
    parser.add_argument('--result-output')
    main(parser.parse_args())
0 commit comments