Skip to content

Commit 0d72a27

Browse files
authored
Merge pull request #194 from bab2min/dev/issue192
자잘한 버그 수정
2 parents c1da90c + f5bfccb commit 0d72a27

File tree

10 files changed

+236
-65
lines changed

10 files changed

+236
-65
lines changed

Diff for: include/kiwi/Dataset.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,19 @@ namespace kiwi
4747
HiddenMember<RaggedVector<uint32_t>, sizeof(Vector<size_t>) * 2> sents;
4848
std::shared_ptr<lm::KnLangModelBase> knlm;
4949
std::unique_ptr<utils::ThreadPool> workers;
50+
std::shared_ptr<KiwiBuilder> dummyBuilder;
5051
std::discrete_distribution<> dropout;
5152
std::mt19937_64 rng;
5253
Vector<ThreadLocal> locals;
5354
Vector<size_t> shuffledIdx;
5455
Vector<int32_t> tokenToVocab, vocabToToken;
56+
Vector<uint8_t> windowTokenValidness;
5557
Deque<OptionalFuture<size_t>> futures;
5658
const Vector<MorphemeRaw>* morphemes = nullptr;
5759
const Vector<FormRaw>* forms = nullptr;
5860
size_t knlmVocabSize = 0;
5961
size_t batchSize = 0;
62+
size_t causalContextSize = 0;
6063
size_t windowSize = 0;
6164
size_t totalTokens = 0;
6265
size_t passedSents = 0;
@@ -68,7 +71,7 @@ namespace kiwi
6871
size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut);
6972

7073
public:
71-
HSDataset(size_t _batchSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
74+
HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
7275
~HSDataset();
7376
HSDataset(const HSDataset&) = delete;
7477
HSDataset(HSDataset&&) /*noexcept*/;
@@ -80,7 +83,9 @@ namespace kiwi
8083
size_t numTokens() const;
8184

8285
// Returns the value stored in the `batchSize` member.
size_t getBatchSize() const { return batchSize; }
86+
// Added by this commit: returns the value stored in the `causalContextSize` member.
size_t getCausalContextSize() const { return causalContextSize; }
8387
// Returns the value stored in the `windowSize` member.
size_t getWindowSize() const { return windowSize; }
88+
// Added by this commit: read-only reference to the `windowTokenValidness` table
// (no copy is made; the reference is to the member itself).
const Vector<uint8_t>& getWindowTokenValidness() const { return windowTokenValidness; }
8489

8590
void seed(size_t newSeed);
8691
void reset();

Diff for: include/kiwi/Kiwi.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,8 @@ namespace kiwi
548548

549549
using MorphemeMap = UnorderedMap<std::tuple<KString, uint8_t, POSTag>, std::pair<size_t, size_t>>;
550550

551+
void initMorphemes();
552+
551553
template<class Fn>
552554
MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);
553555

@@ -612,8 +614,7 @@ namespace kiwi
612614
std::vector<std::string> corpora;
613615
size_t minMorphCnt = 10;
614616
size_t lmOrder = 4;
615-
size_t lmMinCnt = 1;
616-
size_t lmLastOrderMinCnt = 2;
617+
std::vector<size_t> lmMinCnts = { 1 };
617618
size_t numWorkers = 1;
618619
size_t sbgSize = 1000000;
619620
bool useLmTagHistory = true;
@@ -801,11 +802,14 @@ namespace kiwi
801802
using TokenFilter = std::function<bool(const std::u16string&, POSTag)>;
802803

803804
HSDataset makeHSDataset(const std::vector<std::string>& inputPathes,
804-
size_t batchSize, size_t windowSize, size_t numWorkers,
805+
size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
805806
double dropoutProb = 0,
806807
const TokenFilter& tokenFilter = {},
808+
const TokenFilter& windowFilter = {},
807809
double splitRatio = 0,
808810
bool separateDefaultMorpheme = false,
811+
const std::string& morphemeDefPath = {},
812+
size_t morphemeDefMinCnt = 0,
809813
HSDataset* splitDataset = nullptr
810814
) const;
811815
};

Diff for: include/kiwi/TagUtils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace kiwi
3535
// Reports whether `tag` is a suffix tag.
// The tag is first passed through clearIrregular(), then tested against an
// inclusive range of POSTag enum values starting at POSTag::xsn.
// Diff note: the removed (-) line ended the range at POSTag::xsa; the added (+)
// line ends it at POSTag::xsm, so tags between xsa and xsm now also count.
inline bool isSuffix(POSTag tag)
3636
{
3737
tag = clearIrregular(tag);
38-
return POSTag::xsn <= tag && tag <= POSTag::xsa;
38+
return POSTag::xsn <= tag && tag <= POSTag::xsm;
3939
}
4040

4141
inline bool isSpecialClass(POSTag tag)

Diff for: src/Dataset.cpp

+40-11
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
using namespace kiwi;
55

6-
// Constructs an HSDataset from its sizing parameters.
// Diff note: the added (+) signature inserts `_causalContextSize` between
// `_batchSize` and `_windowSize`, which shifts the position of every later argument
// for positional callers.
//   _batchSize          stored in `batchSize`
//   _causalContextSize  stored in `causalContextSize` (new in this commit)
//   _windowSize         stored in `windowSize`
//   _workers            0 means no thread pool is created
//   _dropoutProb        used to build the 4-way `dropout` distribution below
HSDataset::HSDataset(size_t _batchSize, size_t _windowSize, size_t _workers, double _dropoutProb)
6+
HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb)
77
// A thread pool is only allocated when a non-zero worker count is requested.
: workers{ _workers ? make_unique<utils::ThreadPool>(_workers) : nullptr },
88
// Four weights: {1 - 3p, p, p, p}.
// NOTE(review): the first weight becomes negative when _dropoutProb > 1/3;
// std::discrete_distribution expects non-negative weights — confirm callers bound it.
dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} },
99
// One ThreadLocal slot per pool thread, or a single slot when there is no pool.
// This reads `workers`, which is declared before `locals` in the class and is
// therefore already initialized at this point.
locals( _workers ? workers->size() : 1),
1010
batchSize{ _batchSize },
11+
causalContextSize{ _causalContextSize },
1112
windowSize{ _windowSize }
1213
{
1314
}
@@ -113,12 +114,21 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
113114

114115
local.lmLProbsBuf.resize(tokens.size());
115116
local.outNgramNodeBuf.resize(tokens.size());
116-
knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
117+
if (knlm)
118+
{
119+
knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
120+
}
117121

118122
auto& history = local.historyBuf;
119123
history.clear();
120-
history.resize(windowSize, -1);
121-
history.back() = tokenToVocab[tokens[0]];
124+
if (windowSize)
125+
{
126+
history.resize(windowSize, -1);
127+
if (windowTokenValidness[tokens[0]])
128+
{
129+
history.back() = tokenToVocab[tokens[0]];
130+
}
131+
}
122132
for (size_t i = 1; i < tokens.size(); ++i)
123133
{
124134
int32_t v = tokenToVocab[tokens[i]];
@@ -134,13 +144,32 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
134144
local.restLmLProbsCntData[r] += 1;
135145
continue;
136146
}
137-
std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
147+
148+
// Causal context: append `causalContextSize` entries to inData for target position i.
// Entry j maps tokens[i + j - causalContextSize] through tokenToVocab, i.e. the
// tokens immediately preceding position i, oldest first. Slots that would fall
// before the start of `tokens` (i + j < causalContextSize) are filled with `nonVocab`.
// These entries are not filtered through windowTokenValidness.
// The outer guard is equivalent to the loop condition (the loop body never runs
// when causalContextSize is 0).
if (causalContextSize)
149+
{
150+
for (size_t j = 0; j < causalContextSize; ++j)
151+
{
152+
local.inData.emplace_back(i + j < causalContextSize ?
153+
nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]);
154+
}
155+
}
// Window context: always contributes exactly `windowSize` entries per output so the
// per-sample stride of inData stays at (causalContextSize + windowSize).
156+
if (windowSize)
157+
{
// NOTE(review): the table is indexed here by the vocab id `v`
// (v = tokenToVocab[tokens[i]]), whereas the history-seeding code earlier in this
// function indexes it by the raw token id (windowTokenValidness[tokens[0]]).
// Confirm which index space windowTokenValidness is built for.
158+
if (windowTokenValidness[v])
159+
{
// Valid token: emit the current history, then slide the history forward by one.
160+
std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
161+
history.pop_front();
162+
history.push_back(v);
163+
}
164+
else
165+
{
// Invalid token: emit `windowSize` entries of -1 and leave the history unchanged.
166+
local.inData.resize(local.inData.size() + windowSize, -1);
167+
}
168+
}
169+
138170
local.outData.emplace_back(v);
139171
local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]);
140172
local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]);
141-
142-
history.pop_front();
143-
history.push_back(v);
144173
}
145174

146175
size_t r = local.outData.size() / batchSize;
@@ -217,14 +246,14 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
217246
auto& l = locals[localId];
218247

219248
size_t rest = std::min(l.outData.size(), batchSize);
220-
std::copy(l.inData.begin(), l.inData.begin() + rest * windowSize, in);
249+
std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in);
221250
std::copy(l.outData.begin(), l.outData.begin() + rest, out);
222251
std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs);
223252
std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode);
224253
restLmOut = l.restLmLProbsData.front();
225254
restLmCntOut = l.restLmLProbsCntData.front();
226255

227-
l.inData.erase(l.inData.begin(), l.inData.begin() + rest * windowSize);
256+
l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize));
228257
l.outData.erase(l.outData.begin(), l.outData.begin() + rest);
229258
l.lmLProbsData.erase(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest);
230259
l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest);
@@ -245,7 +274,7 @@ size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outN
245274

246275
// Returns knlm->nonLeafNodeSize() for the attached language model.
// Diff note: the removed (-) line dereferenced `knlm` unconditionally; the added (+)
// line returns 0 when `knlm` is null, so a dataset without a language model no
// longer dereferences a null pointer here.
size_t HSDataset::ngramNodeSize() const
247276
{
248-
return knlm->nonLeafNodeSize();
277+
return knlm ? knlm->nonLeafNodeSize() : 0;
249278
}
250279

251280
const MorphemeRaw& HSDataset::vocabInfo(uint32_t vocab) const

Diff for: src/KTrie.cpp

+19-9
Original file line numberDiff line numberDiff line change
@@ -653,8 +653,8 @@ size_t kiwi::splitByTrie(
653653
const auto scanStart = max(endPosMap[nBeginWithMultiplier].first, (uint32_t)1), scanEnd = endPosMap[nBeginWithMultiplier].second;
654654
const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
655655
{
656-
const auto start = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
657-
return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos == start || specialStartPos == start);
656+
const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
657+
return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos * posMultiplier == startPos || specialStartPos * posMultiplier == startPos);
658658
});
659659

660660
// insert unknown form
@@ -742,7 +742,7 @@ size_t kiwi::splitByTrie(
742742
const auto scanStart = max(endPosMap[unkFormEndPos * posMultiplier].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos * posMultiplier].second;
743743
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
744744
{
745-
size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
745+
const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
746746
return startPos == lastSpecialEndPos * posMultiplier && g.endPos == unkFormEndPos * posMultiplier;
747747
});
748748
if (unkFormEndPos > lastSpecialEndPos && !duplicated)
@@ -1215,9 +1215,10 @@ size_t kiwi::splitByTrie(
12151215
return n + startOffset;
12161216
}
12171217

1218-
template<ArchType arch>
1218+
template<ArchType arch, bool typoTolerant>
12191219
const Form* kiwi::findForm(
12201220
const utils::FrozenTrie<kchar_t, const Form*>& trie,
1221+
const Form* formData,
12211222
const KString& str
12221223
)
12231224
{
@@ -1228,7 +1229,12 @@ const Form* kiwi::findForm(
12281229
if (!node) return nullptr;
12291230
}
12301231
if (trie.hasSubmatch(node->val(trie))) return nullptr;
1231-
return node->val(trie);
1232+
auto ret = node->val(trie);
1233+
if (typoTolerant)
1234+
{
1235+
ret = &reinterpret_cast<const TypoForm*>(ret)->form(formData);
1236+
}
1237+
return ret;
12321238
}
12331239

12341240
namespace kiwi
@@ -1266,19 +1272,23 @@ FnSplitByTrie kiwi::getSplitByTrieFn(ArchType arch, bool typoTolerant, bool cont
12661272

12671273
namespace kiwi
12681274
{
1275+
// Getter used to populate a per-architecture function table for findForm.
// Wrapper<i>::value is the address of findForm instantiated for
// static_cast<ArchType>(i).
// Diff note: the struct is now itself a template on `typoTolerant`, and the added (+)
// line forwards that flag as findForm's second template argument.
template<bool typoTolerant>
12691276
struct FindFormGetter
12701277
{
12711278
template<std::ptrdiff_t i>
12721279
struct Wrapper
12731280
{
1274-
static constexpr FnFindForm value = &findForm<static_cast<ArchType>(i)>;
1281+
static constexpr FnFindForm value = &findForm<static_cast<ArchType>(i), typoTolerant>;
12751282
};
12761283
};
12771284
}
12781285

1279-
// Returns the findForm function pointer matching `arch` (and, in the added (+)
// version, `typoTolerant`).
// Diff note: the old version had a single per-arch table; the new version keeps two,
// built from FindFormGetter<false> and FindFormGetter<true>, and picks index 1 when
// `typoTolerant` is true, index 0 otherwise.
// `arch` is cast to std::ptrdiff_t and used directly as the index into the chosen table.
// NOTE(review): presumably tp::Table fills one slot per AvailableArch entry from
// Getter::Wrapper<i>::value — confirm against tp::Table's definition.
FnFindForm kiwi::getFindFormFn(ArchType arch)
1286+
FnFindForm kiwi::getFindFormFn(ArchType arch, bool typoTolerant)
12801287
{
// Function-local static: the table(s) are constructed once and reused on later calls.
1281-
static tp::Table<FnFindForm, AvailableArch> table{ FindFormGetter{} };
1288+
static std::array<tp::Table<FnFindForm, AvailableArch>, 2> table{
1289+
FindFormGetter<false>{},
1290+
FindFormGetter<true>{},
1291+
};
12821292

1283-
return table[static_cast<std::ptrdiff_t>(arch)];
1293+
return table[typoTolerant ? 1 : 0][static_cast<std::ptrdiff_t>(arch)];
12841294
}

Diff for: src/KTrie.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,18 @@ namespace kiwi
103103
const PretokenizedSpanGroup::Span* pretokenizedLast
104104
);
105105

106-
template<ArchType arch>
106+
template<ArchType arch, bool typoTolerant>
107107
const Form* findForm(
108108
const utils::FrozenTrie<kchar_t, const Form*>& trie,
109+
const Form* formData,
109110
const KString& str
110111
);
111112

112113
using FnSplitByTrie = decltype(&splitByTrie<ArchType::default_>);
113114
FnSplitByTrie getSplitByTrieFn(ArchType arch, bool typoTolerant, bool continualTypoTolerant, bool lengtheningTypoTolerant);
114115

115-
using FnFindForm = decltype(&findForm<ArchType::default_>);
116-
FnFindForm getFindFormFn(ArchType arch);
116+
using FnFindForm = decltype(&findForm<ArchType::default_, false>);
117+
FnFindForm getFindFormFn(ArchType arch, bool typoTolerant);
117118

118119
struct KTrie : public utils::TrieNode<char16_t, const Form*, utils::ConstAccess<map<char16_t, int32_t>>, KTrie>
119120
{

Diff for: src/Kiwi.cpp

+9-7
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ namespace kiwi
5353
typoTolerant,
5454
continualTypoTolerant,
5555
lengtheningTypoTolerant);
56-
dfFindForm = (void*)getFindFormFn(selectedArch);
56+
dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant);
5757

5858
static tp::Table<FnFindBestPath, AvailableArch> lmKnLM_8{ FindBestPathGetter<WrappedKnLM<uint8_t>::type>{} };
5959
static tp::Table<FnFindBestPath, AvailableArch> lmKnLM_16{ FindBestPathGetter<WrappedKnLM<uint16_t>::type>{} };
@@ -802,7 +802,8 @@ namespace kiwi
802802
const Vector<uint32_t>& positionTable,
803803
const KString& normStr,
804804
FnFindForm findForm,
805-
const utils::FrozenTrie<kchar_t, const Form*>& formTrie
805+
const utils::FrozenTrie<kchar_t, const Form*>& formTrie,
806+
const Form* formData
806807
)
807808
{
808809
if (pretokenized.empty()) return;
@@ -833,7 +834,7 @@ namespace kiwi
833834
if (s.tokenization.empty())
834835
{
835836
auto formStr = normStr.substr(span.begin, span.end - span.begin);
836-
span.form = findForm(formTrie, formStr); // reuse the predefined form & morpheme
837+
span.form = findForm(formTrie, formData, formStr); // reuse the predefined form & morpheme
837838
if (!span.form) // or use a fallback form
838839
{
839840
span.form = formTrie.value((size_t)POSTag::nnp);
@@ -842,7 +843,7 @@ namespace kiwi
842843
else if (s.tokenization.size() == 1)
843844
{
844845
auto formStr = normalizeHangul(s.tokenization[0].form);
845-
auto* tform = findForm(formTrie, formStr);
846+
auto* tform = findForm(formTrie, formData, formStr);
846847
if (tform && tform->candidate.size() == 1 &&
847848
areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
848849
// reuse the predefined form & morpheme
@@ -908,7 +909,7 @@ namespace kiwi
908909
{
909910
auto& t = s.tokenization[i];
910911
auto formStr = normalizeHangul(t.form);
911-
auto* tform = findForm(formTrie, formStr);
912+
auto* tform = findForm(formTrie, formData, formStr);
912913
const Morpheme* foundMorph = nullptr;
913914
if (tform)
914915
{
@@ -999,7 +1000,8 @@ namespace kiwi
9991000
positionTable,
10001001
normalizedStr,
10011002
reinterpret_cast<FnFindForm>(dfFindForm),
1002-
formTrie
1003+
formTrie,
1004+
forms.data()
10031005
);
10041006

10051007
// Generate an eojeol (word-segment) index for each individual character in the sentence being analyzed
@@ -1317,7 +1319,7 @@ namespace kiwi
13171319
void Kiwi::findMorpheme(vector<const Morpheme*>& ret, const u16string& s, POSTag tag) const
13181320
{
13191321
auto normalized = normalizeHangul(s);
1320-
auto form = (*reinterpret_cast<FnFindForm>(dfFindForm))(formTrie, normalized);
1322+
auto form = (*reinterpret_cast<FnFindForm>(dfFindForm))(formTrie, forms.data(), normalized);
13211323
if (!form) return;
13221324
tag = clearIrregular(tag);
13231325
for (auto c : form->candidate)

0 commit comments

Comments
 (0)