Skip to content

Commit 2c0ee92

Browse files
committed
improved efficiency of splitting long text
1 parent 6f5c94b commit 2c0ee92

File tree

1 file changed

+44
-47
lines changed

1 file changed

+44
-47
lines changed

Diff for: src/KTrie.cpp

+44-47
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace kiwi
1717
{
1818
static constexpr uint32_t npos = -1;
1919

20-
if (endPosMap[startPos].first == npos)
20+
if (endPosMap[startPos].first == endPosMap[startPos].second)
2121
{
2222
return false;
2323
}
@@ -30,15 +30,15 @@ namespace kiwi
3030
nnode.prev = newId - endPosMap[startPos].first;
3131
if (nnode.endPos >= endPosMap.size()) return true;
3232

33-
if (endPosMap[nnode.endPos].first == npos)
33+
if (endPosMap[nnode.endPos].first == endPosMap[nnode.endPos].second)
3434
{
3535
endPosMap[nnode.endPos].first = newId;
36-
endPosMap[nnode.endPos].second = newId;
36+
endPosMap[nnode.endPos].second = newId + 1;
3737
}
3838
else
3939
{
40-
nodes[endPosMap[nnode.endPos].second].sibling = newId - endPosMap[nnode.endPos].second;
41-
endPosMap[nnode.endPos].second = newId;
40+
nodes[endPosMap[nnode.endPos].second - 1].sibling = newId - (endPosMap[nnode.endPos].second - 1);
41+
endPosMap[nnode.endPos].second = newId + 1;
4242
}
4343
return true;
4444
}
@@ -118,46 +118,34 @@ namespace kiwi
118118
return true;
119119
}
120120

121-
inline void removeUnconnected(Vector<KGraphNode>& ret, const Vector<KGraphNode>& graph)
121+
inline void removeUnconnected(Vector<KGraphNode>& ret, const Vector<KGraphNode>& graph, const Vector<std::pair<uint32_t, uint32_t>>& endPosMap)
122122
{
123-
Vector<uint8_t> connectedList(graph.size());
124-
Vector<uint16_t> newIndexDiff(graph.size());
125-
connectedList[graph.size() - 1] = true;
126-
connectedList[0] = true;
127-
// forward searching
128-
for (size_t i = 1; i < graph.size(); ++i)
129-
{
130-
bool connected = false;
131-
for (auto prev = graph[i].getPrev(); prev; prev = prev->getSibling())
123+
thread_local Vector<uint8_t> connectedList;
124+
thread_local Vector<uint16_t> newIndexDiff;
125+
thread_local Deque<uint32_t> updateList;
126+
connectedList.clear();
127+
connectedList.resize(graph.size());
128+
newIndexDiff.clear();
129+
newIndexDiff.resize(graph.size());
130+
updateList.clear();
131+
updateList.emplace_back(graph.size() - 1);
132+
connectedList[graph.size() - 1] = 1;
133+
134+
while (!updateList.empty())
135+
{
136+
const auto id = updateList.front();
137+
updateList.pop_front();
138+
const auto& node = graph[id];
139+
const auto scanStart = endPosMap[node.startPos].first, scanEnd = endPosMap[node.startPos].second;
140+
for (auto i = scanStart; i < scanEnd; ++i)
132141
{
133-
if (connectedList[prev - graph.data()])
134-
{
135-
connected = true;
136-
break;
137-
}
142+
if (graph[i].endPos != node.startPos) continue;
143+
if (connectedList[i]) continue;
144+
updateList.emplace_back(i);
138145
}
139-
connectedList[i] = connected ? 1 : 0;
140-
}
141-
// backward searching
142-
for (size_t i = graph.size() - 1; i-- > 1; )
143-
{
144-
bool connected = false;
145-
for (size_t j = i + 1; j < graph.size(); ++j)
146-
{
147-
for (auto prev = graph[j].getPrev(); prev; prev = prev->getSibling())
148-
{
149-
if (prev > &graph[i]) break;
150-
if (prev < &graph[i]) continue;
151-
if (connectedList[j])
152-
{
153-
connected = true;
154-
goto break_2;
155-
}
156-
}
157-
}
158-
break_2:
159-
connectedList[i] = (connectedList[i] && connected) ? 1 : 0;
146+
fill(connectedList.begin() + scanStart, connectedList.begin() + scanEnd, 1);
160147
}
148+
161149
size_t connectedCnt = accumulate(connectedList.begin(), connectedList.end(), 0);
162150
newIndexDiff[0] = connectedList[0];
163151
for (size_t i = 1; i < graph.size(); ++i)
@@ -231,10 +219,15 @@ size_t kiwi::splitByTrie(
231219
const PretokenizedSpanGroup::Span* pretokenizedLast
232220
)
233221
{
222+
/*
223+
* endPosMap[i]에는 out[x].endPos == i를 만족하는 첫번째 x(first)와 마지막 x + 1(second)가 들어 있다.
224+
* first == second인 경우 endPos가 i인 노드가 없다는 것을 의미한다.
225+
* first <= x && x < second인 out[x] 중에는 endPos가 i가 아닌 것도 있을 수 있으므로 주의해야 한다.
226+
*/
234227
thread_local Vector<pair<uint32_t, uint32_t>> endPosMap;
235228
endPosMap.clear();
236229
endPosMap.resize(str.size() + 1, make_pair<uint32_t, uint32_t>(-1, -1));
237-
endPosMap[0] = make_pair(0, 0);
230+
endPosMap[0] = make_pair(0, 1);
238231

239232
thread_local Vector<uint32_t> nonSpaces;
240233
nonSpaces.clear();
@@ -259,7 +252,8 @@ size_t kiwi::splitByTrie(
259252
for (auto& cand : candidates)
260253
{
261254
const size_t nBegin = typoTolerant ? candTypoCostStarts[&cand - candidates.data()].start : (nonSpaces.size() - cand->sizeWithoutSpace());
262-
const bool longestMatched = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
255+
const auto scanStart = max(endPosMap[nBegin].first, (uint32_t)1), scanEnd = endPosMap[nBegin].second;
256+
const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
263257
{
264258
return nBegin == g.endPos && lastSpecialEndPos == g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size());
265259
});
@@ -335,7 +329,8 @@ size_t kiwi::splitByTrie(
335329
}
336330
}
337331

338-
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
332+
const auto scanStart = max(endPosMap[unkFormEndPos].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos].second;
333+
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
339334
{
340335
size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size());
341336
return startPos == lastSpecialEndPos && g.endPos == unkFormEndPos;
@@ -483,7 +478,8 @@ size_t kiwi::splitByTrie(
483478
// sequence of speical characters found
484479
if (lastChrType != POSTag::max && lastChrType != POSTag::unknown && lastChrType != lastMatchedPattern)
485480
{
486-
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
481+
const auto scanStart = max(endPosMap[nonSpaces.size()].first, (uint32_t)1), scanEnd = endPosMap[nonSpaces.size()].second;
482+
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
487483
{
488484
return nonSpaces.size() == g.endPos;
489485
});
@@ -635,7 +631,8 @@ size_t kiwi::splitByTrie(
635631
// sequence of speical characters found
636632
if (lastChrType != POSTag::max && lastChrType != POSTag::unknown && !isWebTag(lastChrType))
637633
{
638-
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
634+
const auto scanStart = max(endPosMap[nonSpaces.size()].first, (uint32_t)1), scanEnd = endPosMap[nonSpaces.size()].second;
635+
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
639636
{
640637
return nonSpaces.size() == g.endPos;
641638
});
@@ -667,7 +664,7 @@ size_t kiwi::splitByTrie(
667664

668665
nonSpaces.emplace_back(n);
669666

670-
removeUnconnected(ret, out);
667+
removeUnconnected(ret, out, endPosMap);
671668
for (size_t i = 1; i < ret.size() - 1; ++i)
672669
{
673670
auto& r = ret[i];

0 commit comments

Comments
 (0)