Skip to content

Commit f44d176

Browse files
authored
Merge pull request #170 from bab2min/fix_pretokenized
Pretokenized span 개선 (Improve pretokenized span handling)
2 parents b23bfa4 + 69f5a21 commit f44d176

File tree

4 files changed

+55
-21
lines changed

4 files changed

+55
-21
lines changed

Diff for: include/kiwi/Types.h

+8-2
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,11 @@ namespace kiwi
241241
return irregular ? setIrregular(tag) : clearIrregular(tag);
242242
}
243243

244+
/**
 * Reports whether two part-of-speech tags match.
 *
 * When ignoreRegularity is false (the default), the tags are compared exactly (a == b).
 * When ignoreRegularity is true, both tags are first passed through clearIrregular(),
 * so two tags that differ only in what clearIrregular() removes compare equal.
 *
 * @param a first tag
 * @param b second tag
 * @param ignoreRegularity if true, compare clearIrregular(a) with clearIrregular(b)
 * @return true if the tags are considered equal under the selected comparison
 */
inline constexpr bool areTagsEqual(POSTag a, POSTag b, bool ignoreRegularity = false)
245+
{
246+
return ignoreRegularity ? (clearIrregular(a) == clearIrregular(b)) : (a == b);
247+
}
248+
244249
constexpr size_t defaultTagSize = (size_t)POSTag::p;
245250

246251
/**
@@ -349,9 +354,10 @@ namespace kiwi
349354
std::u16string form;
350355
uint32_t begin = -1, end = -1;
351356
POSTag tag = POSTag::unknown;
357+
uint8_t inferRegularity = 1;
352358

353-
BasicToken(const std::u16string& _form = {}, uint32_t _begin = -1, uint32_t _end = -1, POSTag _tag = POSTag::unknown)
354-
: form{ _form }, begin{ _begin }, end{ _end }, tag{ _tag }
359+
/**
 * Builds a token by copying each argument into the member of the same name
 * (form, begin, end, tag, inferRegularity).
 *
 * Every parameter has a default: an empty form, begin/end of -1 (which converts to
 * the largest uint32_t value), POSTag::unknown, and inferRegularity of 1.
 *
 * NOTE(review): inferRegularity looks like it is forwarded as the ignoreRegularity
 * flag of areTagsEqual() when this token's tag is matched against morphemes —
 * confirm against the callers.
 */
BasicToken(const std::u16string& _form = {}, uint32_t _begin = -1, uint32_t _end = -1, POSTag _tag = POSTag::unknown, uint8_t _inferRegularity = 1)
360+
: form{ _form }, begin{ _begin }, end{ _end }, tag{ _tag }, inferRegularity{ _inferRegularity }
355361
{}
356362
};
357363

Diff for: src/Joiner.cpp

+2-6
Original file line number | Diff line number | Diff line change
@@ -312,11 +312,7 @@ namespace kiwi
312312
Vector<const Morpheme*> cands;
313313
foreachMorpheme(formHead, [&](const Morpheme* m)
314314
{
315-
if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag))
316-
{
317-
cands.emplace_back(m);
318-
}
319-
else if (!inferRegularity && m->tag == fixedTag)
315+
if (areTagsEqual(m->tag, fixedTag, inferRegularity))
320316
{
321317
cands.emplace_back(m);
322318
}
@@ -412,7 +408,7 @@ namespace kiwi
412408
Vector<const Morpheme*> cands;
413409
foreachMorpheme(formHead, [&](const Morpheme* m)
414410
{
415-
if (clearIrregular(m->tag) == clearIrregular(tag))
411+
if (areTagsEqual(m->tag, tag, true))
416412
{
417413
cands.emplace_back(m);
418414
}

Diff for: src/Kiwi.cpp

+17-13
Original file line number | Diff line number | Diff line change
@@ -247,7 +247,7 @@ namespace kiwi
247247
case POSTag::vcp:
248248
case POSTag::etm:
249249
case POSTag::ec:
250-
if (t.tag == POSTag::jx && *t.morph->kform == u"")
250+
if (t.tag == POSTag::jx && t.morph && *t.morph->kform == u"")
251251
{
252252
if (state == State::ef)
253253
{
@@ -804,36 +804,40 @@ namespace kiwi
804804
{
805805
auto formStr = normalizeHangul(s.tokenization[0].form);
806806
auto* tform = findForm(formTrie, formStr);
807-
if (tform && tform->candidate.size() == 1 && tform->candidate[0]->tag == s.tokenization[0].tag) // reuse the predefined form & morpheme
807+
if (tform && tform->candidate.size() == 1 &&
808+
areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
809+
// reuse the predefined form & morpheme
808810
{
809811
span.form = tform;
810812
}
811-
else if (formStr == normStr.substr(span.begin, span.end - span.begin)) // use a fallback form
812-
{
813-
span.form = formTrie.value((size_t)clearIrregular(s.tokenization[0].tag));
814-
}
815813
else // or add a new form & morpheme
816814
{
817815
ret.forms.emplace_back();
818816
auto& form = ret.forms.back();
819817
form.form = move(formStr);
820-
form.candidate = FixedVector<const Morpheme*>{ 1 };
821-
const Morpheme* foundMorph = nullptr;
818+
const Morpheme* foundMorph[2] = { nullptr, nullptr };
822819
if (tform)
823820
{
821+
size_t i = 0;
824822
for (auto m : tform->candidate)
825823
{
826-
if (m->tag == s.tokenization[0].tag)
824+
if (areTagsEqual(m->tag, s.tokenization[0].tag, s.tokenization[0].inferRegularity))
827825
{
828-
foundMorph = m;
829-
break;
826+
foundMorph[i++] = m;
827+
if (i >= 2) break;
830828
}
831829
}
832830
}
831+
832+
form.candidate = FixedVector<const Morpheme*>{ (size_t)(foundMorph[1] ? 2 : 1) };
833833

834-
if (foundMorph)
834+
if (foundMorph[0])
835835
{
836-
form.candidate[0] = foundMorph;
836+
form.candidate[0] = foundMorph[0];
837+
if (foundMorph[1])
838+
{
839+
form.candidate[1] = foundMorph[1];
840+
}
837841
}
838842
else
839843
{

Diff for: test/test_cpp.cpp

+28
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,34 @@ TEST(KiwiCpp, Pretokenized)
271271
EXPECT_EQ(res[13].str, u"매트");
272272
EXPECT_EQ(res[13].tag, POSTag::nng);
273273
}
274+
275+
{
276+
std::vector<PretokenizedSpan> pretokenized = {
277+
PretokenizedSpan{ 9, 10, { BasicToken{ u"", 0, 1, POSTag::jks } } },
278+
PretokenizedSpan{ 16, 17, { BasicToken{ u"", 0, 1, POSTag::jkb } } },
279+
};
280+
281+
auto ref = kiwi.analyze(str, Match::allWithNormalizing).first;
282+
res = kiwi.analyze(str, Match::allWithNormalizing, nullptr, pretokenized).first;
283+
EXPECT_EQ(res[2].tag, POSTag::jks);
284+
EXPECT_EQ(res[2].morph, ref[2].morph);
285+
EXPECT_EQ(res[2].score, ref[2].score);
286+
EXPECT_EQ(res[5].tag, POSTag::jkb);
287+
EXPECT_EQ(res[5].morph, ref[5].morph);
288+
EXPECT_EQ(res[5].score, ref[5].score);
289+
}
290+
291+
{
292+
auto str2 = u"길을 걷다";
293+
std::vector<PretokenizedSpan> pretokenized = {
294+
PretokenizedSpan{ 3, 4, { BasicToken{ u"", 0, 1, POSTag::vv } } },
295+
};
296+
297+
auto ref = kiwi.analyze(str2, Match::allWithNormalizing).first;
298+
res = kiwi.analyze(str2, Match::allWithNormalizing, nullptr, pretokenized).first;
299+
EXPECT_EQ(res[2].tag, POSTag::vvi);
300+
EXPECT_EQ(res[2].morph, ref[2].morph);
301+
}
274302
}
275303

276304
TEST(KiwiCpp, TagRoundTrip)

0 commit comments

Comments
 (0)