Skip to content

Commit 2696a9f

Browse files
authored
Merge pull request #177 from bab2min/dev_fix176
Minor Fix including #176
2 parents 5a20486 + 90b121f commit 2696a9f

22 files changed

+807
-175
lines changed

Diff for: CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ set ( CORE_SRCS
5252
src/Joiner.cpp
5353
src/Kiwi.cpp
5454
src/KiwiBuilder.cpp
55+
src/Knlm.cpp
5556
src/KTrie.cpp
5657
src/PatternMatcher.cpp
5758
src/search.cpp

Diff for: include/kiwi/FrozenTrie.h

+27
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,26 @@ namespace kiwi
9797
std::unique_ptr<Key[]> nextKeys;
9898
std::unique_ptr<Diff[]> nextDiffs;
9999

100+
template<class Fn>
101+
void traverse(Fn&& visitor, const Node* node, std::vector<Key>& prefix, size_t maxDepth) const
102+
{
103+
auto* keys = &nextKeys[node->nextOffset];
104+
auto* diffs = &nextDiffs[node->nextOffset];
105+
for (size_t i = 0; i < node->numNexts; ++i)
106+
{
107+
const auto* child = node + diffs[i];
108+
const auto val = child->val(*this);
109+
if (!hasMatch(val)) continue;
110+
prefix.emplace_back(keys[i]);
111+
visitor(val, prefix);
112+
if (prefix.size() < maxDepth)
113+
{
114+
traverse(visitor, child, prefix, maxDepth);
115+
}
116+
prefix.pop_back();
117+
}
118+
}
119+
100120
public:
101121

102122
FrozenTrie() = default;
@@ -117,6 +137,13 @@ namespace kiwi
117137
/// Returns a read-only reference to the element of `values` at index `idx`; no bounds check is done here.
const Value& value(size_t idx) const { return values[idx]; };
118138

119139
/// True when `v` is not reported as null by `isNull` and not reported as a
/// submatch by `hasSubmatch`.
bool hasMatch(_Value v) const
{
	if (this->isNull(v)) return false;
	return !this->hasSubmatch(v);
}
140+
141+
template<class Fn>
142+
void traverse(Fn&& visitor, size_t maxDepth = -1) const
143+
{
144+
std::vector<Key> prefix;
145+
traverse(std::forward<Fn>(visitor), root(), prefix, maxDepth);
146+
}
120147
};
121148
}
122149
}

Diff for: include/kiwi/Knlm.h

+29-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
#include <algorithm>
88
#include <numeric>
99

10+
#include "Utils.h"
1011
#include "Mmap.h"
12+
#include "ArchUtils.h"
1113

1214
namespace kiwi
1315
{
@@ -20,6 +22,7 @@ namespace kiwi
2022
uint64_t num_nodes, node_offset, key_offset, ll_offset, gamma_offset, qtable_offset, htx_offset;
2123
uint64_t unk_id, bos_id, eos_id, vocab_size;
2224
uint8_t order, key_size, diff_size, quantized;
25+
uint32_t extra_buf_size;
2326
};
2427

2528
template<class KeyType, class DiffType = int32_t>
@@ -43,6 +46,7 @@ namespace kiwi
4346
virtual float _progress(ptrdiff_t& node_idx, size_t next) const = 0;
4447
virtual std::vector<float> allNextLL(ptrdiff_t node_idx) const = 0;
4548
virtual std::vector<float> allNextLL(ptrdiff_t node_idx, std::vector<ptrdiff_t>& next_node_idx) const = 0;
49+
virtual void nextTopN(ptrdiff_t node_idx, size_t top_n, uint32_t* idx_out, float* ll_out) const = 0;
4650

4751
public:
4852

@@ -55,21 +59,28 @@ namespace kiwi
5559
virtual size_t llSize() const = 0;
5660
virtual const float* getLLBuf() const = 0;
5761
virtual const float* getGammaBuf() const = 0;
62+
virtual const void* getExtraBuf() const = 0;
5863

5964
static std::unique_ptr<KnLangModelBase> create(utils::MemoryObject&& mem, ArchType archType = ArchType::none);
6065

61-
template<class TrieNode, class HistoryTx = std::vector<Vid>>
62-
static utils::MemoryOwner build(const utils::ContinuousTrie<TrieNode>& ngram_cf,
63-
size_t order, size_t min_cf, size_t last_min_cf,
66+
template<class Trie, class HistoryTx = std::vector<Vid>>
67+
static utils::MemoryOwner build(Trie&& ngram_cf,
68+
size_t order, const std::vector<size_t>& min_cf_by_order,
6469
size_t unk_id, size_t bos_id, size_t eos_id,
6570
float unigram_alpha, size_t quantize, bool compress,
6671
const std::vector<std::pair<Vid, Vid>>* bigram_list = nullptr,
67-
const HistoryTx* historyTransformer = nullptr
72+
const HistoryTx* history_transformer = nullptr,
73+
const void* extra_buf = nullptr,
74+
size_t extra_buf_size = 0
6875
);
6976

7077
const utils::MemoryObject& getMemory() const { return base; }
7178

72-
//virtual float progress(ptrdiff_t& node_idx, size_t next) const = 0;
79+
template<class Ty>
80+
float progress(ptrdiff_t& node_idx, Ty next) const
81+
{
82+
return _progress(node_idx, next);
83+
}
7384

7485
template<class InTy, class OutTy>
7586
void evaluate(InTy in_first, InTy in_last, OutTy out_first) const
@@ -130,6 +141,19 @@ namespace kiwi
130141
}
131142
}
132143

144+
template<class InTy>
145+
void predictTopN(InTy in_first, InTy in_last, size_t top_n, uint32_t* idx_out, float* ll_out) const
146+
{
147+
ptrdiff_t node_idx = 0;
148+
for (; in_first != in_last; ++in_first)
149+
{
150+
_progress(node_idx, *in_first);
151+
nextTopN(node_idx, top_n, idx_out, ll_out);
152+
idx_out += top_n;
153+
ll_out += top_n;
154+
}
155+
}
156+
133157
template<class PfTy, class SfTy, class OutTy>
134158
void fillIn(PfTy prefix_first, PfTy prefix_last, SfTy suffix_first, SfTy suffix_last, OutTy out_first, bool reduce = true) const
135159
{

Diff for: include/kiwi/Mmap.h

+5
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ namespace kiwi
303303
setp(epptr() + off, epptr());
304304
else if (dir == std::ios_base::beg)
305305
setp(pbase() + off, epptr());
306+
307+
if (!(which & std::ios_base::in))
308+
{
309+
return pptr() - pbase();
310+
}
306311
}
307312
return gptr() - eback();
308313
}

Diff for: include/kiwi/SubstringExtractor.h

+49
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
#include <vector>
44
#include <string>
55

6+
#include <kiwi/FrozenTrie.h>
7+
#include <kiwi/Knlm.h>
8+
69
namespace kiwi
710
{
811
std::vector<std::pair<std::u16string, size_t>> extractSubstrings(
@@ -13,4 +16,50 @@ namespace kiwi
1316
size_t maxLength = 32,
1417
bool longestOnly = true,
1518
char16_t stopChr = 0);
19+
20+
21+
// Accepts arrays of 16/32/64-bit unsigned integers through addArray() and exposes
// count() (returns a utils::FrozenTrie<uint32_t, uint32_t>) and buildLM()
// (returns a std::unique_ptr<lm::KnLangModelBase>). Only declarations are visible
// in this header; the notes below are limited to what the signatures show.
class PrefixCounter
22+
{
23+
// NOTE(review): presumably initialised from the constructor's _prefixSize / _minCf arguments — confirm in the definition.
size_t prefixSize = 0, minCf = 0, numArrays = 0;
24+
// NOTE(review): token2id / id2Token look like a two-way token <-> dense id mapping — confirm.
UnorderedMap<uint32_t, uint32_t> token2id;
25+
Vector<uint32_t> id2Token;
26+
Vector<uint16_t> buf;
27+
Vector<size_t> tokenClusters;
28+
Vector<size_t> tokenCnts;
29+
// Type-erased handle; the concrete pool type is not visible from this header.
std::shared_ptr<void> threadPool;
30+
31+
// Iterator-generic helper; NOTE(review): presumably the shared implementation behind the public addArray overloads — confirm.
template<class It>
32+
void _addArray(It first, It last);
33+
34+
Vector<std::pair<uint32_t, float>> computeClusterScore() const;
35+
36+
public:
37+
// `clusters` defaults to an empty list.
PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers,
38+
const std::vector<std::vector<size_t>>& clusters = {}
39+
);
40+
// One overload per element width. NOTE(review): `first`/`last` presumably delimit a half-open range — confirm.
void addArray(const uint16_t* first, const uint16_t* last);
41+
void addArray(const uint32_t* first, const uint32_t* last);
42+
void addArray(const uint64_t* first, const uint64_t* last);
43+
utils::FrozenTrie<uint32_t, uint32_t> count() const;
44+
// `archType` defaults to ArchType::none.
std::unique_ptr<lm::KnLangModelBase> buildLM(
45+
const std::vector<size_t>& minCfByOrder,
46+
size_t bosTokenId,
47+
size_t eosTokenId,
48+
size_t unkTokenId,
49+
ArchType archType = ArchType::none
50+
) const;
51+
};
52+
53+
// Read-only accessor over an array of (uint32_t, float) pairs held by pointer.
// NOTE(review): the pointer is a plain non-owning member, so the pointed-to
// memory presumably has to outlive this object — confirm in the definition.
class ClusterData
54+
{
55+
// Defaults describe an empty view: null pointer, size 0.
const std::pair<uint32_t, float>* clusterScores = nullptr;
56+
size_t clusterSize = 0;
57+
public:
58+
ClusterData();
59+
// NOTE(review): how `_ptr` / `_size` map onto the pair array (bytes vs. elements) is not visible here — confirm.
ClusterData(const void* _ptr, size_t _size);
60+
61+
size_t size() const;
62+
// NOTE(review): presumably return the first / second member of entry `i` — confirm.
size_t cluster(size_t i) const;
63+
float score(size_t i) const;
64+
};
1665
}

Diff for: include/kiwi/Trie.hpp

+25
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,24 @@ namespace kiwi
132132
return;
133133
}
134134

135+
template<typename _Fn, typename _CKey>
136+
void traverse(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const
137+
{
138+
fn(this->val, rkeys);
139+
140+
if (rkeys.size() >= maxDepth) return;
141+
142+
for (auto& p : next)
143+
{
144+
if (ignoreNegative ? (p.second > 0) : (p.second))
145+
{
146+
rkeys.emplace_back(p.first);
147+
getNext(p.first)->traverse(fn, rkeys, maxDepth, ignoreNegative);
148+
rkeys.pop_back();
149+
}
150+
}
151+
}
152+
135153
template<typename _Fn, typename _CKey>
136154
void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const
137155
{
@@ -462,6 +480,13 @@ namespace kiwi
462480
return nodes[0].fillFail(std::forward<HistoryTx>(htx), ignoreNegative);
463481
}
464482

483+
template<typename _Fn>
484+
void traverse(_Fn&& fn, size_t maxDepth = -1, bool ignoreNegative = false) const
485+
{
486+
std::vector<typename Node::Key> rkeys;
487+
return nodes[0].traverse(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative);
488+
}
489+
465490
template<typename _Fn, typename _CKey>
466491
void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const
467492
{

Diff for: include/kiwi/Utils.h

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
#pragma once
1+
#pragma once
22
#include <iostream>
33
#include <string>
44
#include <memory>
5+
#include <array>
56
#include "Types.h"
67

78
namespace kiwi
@@ -82,6 +83,11 @@ namespace kiwi
8283
return within(chr, 0x302E, 0x3030);
8384
}
8485

86+
inline bool isCompatibleHangulConsonant(char16_t chr)
87+
{
88+
return within(chr, 0x3131, 0x314E) || within(chr, 0x3165, 0x3186);
89+
}
90+
8591
struct ComparatorIgnoringSpace
8692
{
8793
static bool less(const KString& a, const KString& b, const kchar_t space = u' ');
@@ -146,6 +152,38 @@ namespace kiwi
146152
return joinHangul(hangul.begin(), hangul.end());
147153
}
148154

155+
/// True when `c` is a UTF-16 high (leading) surrogate, i.e. in 0xD800..0xDBFF.
inline bool isHighSurrogate(char16_t c)
{
	return 0xD800 <= c && c <= 0xDBFF;
}
159+
160+
/// True when `c` is a UTF-16 low (trailing) surrogate, i.e. in 0xDC00..0xDFFF.
inline bool isLowSurrogate(char16_t c)
{
	return 0xDC00 <= c && c <= 0xDFFF;
}
164+
165+
/// Combines a surrogate pair into one code point: the low 10 bits of `h`
/// become bits 10..19, the low 10 bits of `l` become bits 0..9, and 0x10000 is
/// added. The inputs are not checked for actually being surrogates.
inline char32_t mergeSurrogate(char16_t h, char16_t l)
{
	const int highBits = (h & 0x3FF) << 10;
	const int lowBits = l & 0x3FF;
	return 0x10000 + (highBits | lowBits);
}
169+
170+
/// Splits a code point into UTF-16 code units.
///
/// @param c code point to encode.
/// @return `{c, 0}` when `c < 0x10000`; otherwise `{high, low}`, where the
///         high unit is `0xD800 | bits 10..19` and the low unit is
///         `0xDC00 | bits 0..9` of `c - 0x10000`.
///
/// No validation is performed: values inside the surrogate range are returned
/// unchanged and bits above the 20-bit payload are masked off.
/// The 32-bit to 16-bit conversions are explicit casts so the function builds
/// cleanly under narrowing-conversion warnings.
inline std::array<char16_t, 2> decomposeSurrogate(char32_t c)
{
	if (c < 0x10000)
	{
		return { { static_cast<char16_t>(c), u'\0' } };
	}

	const char32_t payload = c - 0x10000;
	return { {
		static_cast<char16_t>(0xD800 | ((payload >> 10) & 0x3FF)),
		static_cast<char16_t>(0xDC00 | (payload & 0x3FF))
	} };
}
186+
149187
POSTag identifySpecialChr(char32_t chr);
150188
size_t getSSType(char16_t c);
151189
size_t getSBType(const std::u16string& form);

Diff for: src/FrozenTrie.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ namespace kiwi
111111
for (size_t i = 0; i < trie.size(); ++i)
112112
{
113113
auto& o = trie[i];
114-
nodes[i].numNexts = o.next.size();
114+
nodes[i].numNexts = (Key)o.next.size();
115115
values[i] = xform(o);
116116
nodes[i].nextOffset = ptr;
117117

Diff for: src/Joiner.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ namespace kiwi
352352
for (size_t i = 0; i < candidates.size(); ++i)
353353
{
354354
auto& c = candidates[i];
355-
auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i));
355+
auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i));
356356
if (!inserted.second)
357357
{
358358
if (inserted.first->second.first < c.score)

Diff for: src/Kiwi.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,7 @@ namespace kiwi
847847
morph.tag = s.tokenization[0].tag;
848848
morph.vowel = CondVowel::none;
849849
morph.polar = CondPolarity::none;
850+
morph.complex = 0;
850851
morph.lmMorphemeId = getDefaultMorphemeId(s.tokenization[0].tag);
851852
form.candidate[0] = &morph;
852853
}

Diff for: src/KiwiBuilder.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -723,10 +723,11 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
723723
new (&pool) utils::ThreadPool{ args.numWorkers };
724724
}
725725
auto cntNodes = utils::count(sents.begin(), sents.end(), args.lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr);
726-
cntNodes.root().getNext(lmVocabSize)->val /= 2;
726+
std::vector<size_t> minCnts(args.lmOrder, args.lmMinCnt);
727+
minCnts.back() = args.lmLastOrderMinCnt;
727728
langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build(
728729
cntNodes,
729-
args.lmOrder, args.lmMinCnt, args.lmLastOrderMinCnt,
730+
args.lmOrder, minCnts,
730731
2, 0, 1, 1e-5,
731732
args.quantizeLm ? 8 : 0,
732733
args.compressLm,

Diff for: src/Knlm.cpp

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include "Knlm.hpp"
2+
3+
namespace kiwi
4+
{
5+
namespace lm
6+
{
7+
template<ArchType archType>
8+
std::unique_ptr<KnLangModelBase> createOptimizedModel(utils::MemoryObject&& mem)
9+
{
10+
auto* ptr = reinterpret_cast<const char*>(mem.get());
11+
auto& header = *reinterpret_cast<const Header*>(ptr);
12+
switch (header.key_size)
13+
{
14+
case 1:
15+
return make_unique<KnLangModel<archType, uint8_t>>(std::move(mem));
16+
case 2:
17+
return make_unique<KnLangModel<archType, uint16_t>>(std::move(mem));
18+
case 4:
19+
return make_unique<KnLangModel<archType, uint32_t>>(std::move(mem));
20+
case 8:
21+
return make_unique<KnLangModel<archType, uint64_t>>(std::move(mem));
22+
default:
23+
throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.key_size) };
24+
}
25+
}
26+
27+
using FnCreateOptimizedModel = decltype(&createOptimizedModel<ArchType::none>);
28+
29+
// Maps a compile-time integer `i` to the address of the `createOptimizedModel`
// instantiation for `static_cast<ArchType>(i)`. An instance of this struct is
// passed to the `tp::Table` constructor in `KnLangModelBase::create`.
struct CreateOptimizedModelGetter
30+
{
31+
template<std::ptrdiff_t i>
32+
struct Wrapper
33+
{
34+
// Function pointer for the ArchType whose underlying value is `i`.
static constexpr FnCreateOptimizedModel value = &createOptimizedModel<static_cast<ArchType>(i)>;
35+
};
36+
};
37+
38+
/// Looks up the factory for `archType` in a function-local static table and
/// uses it to construct the model from `mem`.
///
/// @throws std::runtime_error when the table has no factory for `archType`.
std::unique_ptr<KnLangModelBase> KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType)
{
	static tp::Table<FnCreateOptimizedModel, AvailableArch> table{ CreateOptimizedModelGetter{} };

	const auto factory = table[static_cast<std::ptrdiff_t>(archType)];
	if (factory)
	{
		return (*factory)(std::move(mem));
	}
	throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) };
}
45+
}
46+
}

0 commit comments

Comments
 (0)