Skip to content

Commit 6543a1f

Browse files
committed
Fix compilation errors
1 parent 00d7a2e commit 6543a1f

File tree

3 files changed

+199
-199
lines changed

3 files changed

+199
-199
lines changed

Diff for: src/KiwiBuilder.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
#include <kiwi/Kiwi.h>
55
#include <kiwi/Utils.h>
66
#include <kiwi/Dataset.h>
7+
#include <kiwi/Knlm.h>
78
#include "ArchAvailable.h"
89
#include "KTrie.h"
910
#include "StrUtils.h"
1011
#include "FrozenTrie.hpp"
11-
#include "Knlm.hpp"
1212
#include "serializer.hpp"
1313
#include "count.hpp"
1414
#include "FeatureTestor.h"

Diff for: src/Knlm.cpp

-198
Original file line numberDiff line numberDiff line change
@@ -7,204 +7,6 @@ namespace kiwi
77
{
88
namespace lm
99
{
10-
template<ArchType arch, class KeyType, bool transposed, class DiffType>
11-
template<ptrdiff_t ...idx>
12-
void KnLangModel<arch, KeyType, transposed, DiffType>::dequantizeDispatch(
13-
tp::seq<idx...>,
14-
size_t bits,
15-
Vector<float>& restored_floats, Vector<float>& restored_leaf_ll,
16-
const char* llq_data, size_t llq_size,
17-
const char* gammaq_data, size_t gammaq_size,
18-
const float* ll_table,
19-
const float* gamma_table,
20-
size_t num_non_leaf_nodes,
21-
size_t num_leaf_nodes
22-
)
23-
{
24-
using Fn = void(*)(Vector<float>&, Vector<float>&,
25-
const char*, size_t,
26-
const char*, size_t,
27-
const float*,
28-
const float*,
29-
size_t,
30-
size_t);
31-
static constexpr Fn table[] = {
32-
&dequantize<idx + 1>...
33-
};
34-
return table[bits - 1](restored_floats, restored_leaf_ll,
35-
llq_data, llq_size,
36-
gammaq_data, gammaq_size,
37-
ll_table, gamma_table,
38-
num_non_leaf_nodes, num_leaf_nodes
39-
);
40-
}
41-
42-
template<ArchType arch, class KeyType, bool transposed, class DiffType>
43-
KnLangModel<arch, KeyType, transposed, DiffType>::KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) }
44-
{
45-
auto* ptr = reinterpret_cast<const char*>(base.get());
46-
auto& header = getHeader();
47-
const size_t quantized = header.quantized & 0x1F;
48-
const bool compressed = header.quantized & 0x80;
49-
50-
Vector<KeyType> d_node_size;
51-
auto* node_sizes = reinterpret_cast<const KeyType*>(ptr + header.node_offset);
52-
key_data = make_unique<KeyType[]>((header.ll_offset - header.key_offset) / sizeof(KeyType));
53-
std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset);
54-
size_t num_leaf_nodes = 0;
55-
if (compressed)
56-
{
57-
d_node_size.resize(header.num_nodes);
58-
auto qc_header = reinterpret_cast<const uint8_t*>(ptr + header.node_offset);
59-
auto qc_body = reinterpret_cast<const size_t*>(qc_header + (header.num_nodes + 3) / 4);
60-
QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes);
61-
node_sizes = d_node_size.data();
62-
}
63-
64-
for (size_t i = 0; i < header.num_nodes; ++i)
65-
{
66-
if (node_sizes[i]) num_non_leaf_nodes++;
67-
else num_leaf_nodes++;
68-
}
69-
70-
// restore ll & gamma data
71-
Vector<float> restored_leaf_ll, restored_floats;
72-
const float* ll_data = nullptr;
73-
const float* gamma_data = nullptr;
74-
const float* leaf_ll_data = nullptr;
75-
if (quantized)
76-
{
77-
if (quantized > 16)
78-
{
79-
throw std::runtime_error{ "16+ bits quantization not supported." };
80-
}
81-
82-
restored_floats.resize(num_non_leaf_nodes * 2);
83-
restored_leaf_ll.resize(num_leaf_nodes);
84-
leaf_ll_data = restored_leaf_ll.data();
85-
ll_data = &restored_floats[0];
86-
gamma_data = &restored_floats[num_non_leaf_nodes];
87-
88-
const float* ll_table = reinterpret_cast<const float*>(ptr + header.qtable_offset);
89-
const float* gamma_table = ll_table + ((size_t)1 << quantized);
90-
91-
dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll,
92-
ptr + header.ll_offset, header.gamma_offset - header.ll_offset,
93-
ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset,
94-
ll_table,
95-
gamma_table,
96-
num_non_leaf_nodes,
97-
num_leaf_nodes
98-
);
99-
extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized));
100-
}
101-
else
102-
{
103-
ll_data = reinterpret_cast<const float*>(ptr + header.ll_offset);
104-
gamma_data = reinterpret_cast<const float*>(ptr + header.gamma_offset);
105-
leaf_ll_data = ll_data + num_non_leaf_nodes;
106-
extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes);
107-
}
108-
109-
size_t htx_vocab_size = header.vocab_size;
110-
if (header.htx_offset)
111-
{
112-
htx_data = reinterpret_cast<const KeyType*>(ptr + header.htx_offset);
113-
htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1;
114-
extra_buf = toAlignedPtr(htx_data + header.vocab_size);
115-
}
116-
117-
if (!header.extra_buf_size)
118-
{
119-
extra_buf = nullptr;
120-
}
121-
122-
// restore node's data
123-
node_data = make_unique<MyNode[]>(num_non_leaf_nodes);
124-
all_value_data = make_unique<DiffType[]>(header.num_nodes - 1 + htx_vocab_size);
125-
value_data = &all_value_data[htx_vocab_size];
126-
std::fill(&all_value_data[0], value_data, 0);
127-
128-
size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0;
129-
Vector<std::array<size_t, 3>> key_ranges;
130-
for (size_t i = 0; i < header.num_nodes; ++i)
131-
{
132-
if (node_sizes[i])
133-
{
134-
auto& node = node_data[non_leaf_idx];
135-
if (!key_ranges.empty())
136-
{
137-
auto& back = key_ranges.back();
138-
value_data[back[1]] = non_leaf_idx - back[0];
139-
}
140-
node.num_nexts = node_sizes[i];
141-
node.next_offset = next_offset;
142-
node.ll = ll_data[non_leaf_idx];
143-
node.gamma = gamma_data[non_leaf_idx];
144-
next_offset += node_sizes[i];
145-
key_ranges.emplace_back(std::array<size_t, 3>{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) });
146-
non_leaf_idx++;
147-
}
148-
else
149-
{
150-
auto& back = key_ranges.back();
151-
reinterpret_cast<float&>(value_data[back[1]]) = leaf_ll_data[leaf_idx];
152-
back[1]++;
153-
while (key_ranges.back()[1] == key_ranges.back()[2])
154-
{
155-
key_ranges.pop_back();
156-
if (key_ranges.empty()) break;
157-
key_ranges.back()[1]++;
158-
}
159-
leaf_idx++;
160-
}
161-
}
162-
163-
for (size_t i = 0; i < node_data[0].num_nexts; ++i)
164-
{
165-
auto k = key_data[i];
166-
auto v = value_data[i];
167-
all_value_data[k] = v;
168-
}
169-
170-
Vector<uint8_t> tempBuf;
171-
for (size_t i = 0; i < non_leaf_idx; ++i)
172-
{
173-
auto& node = node_data[i];
174-
nst::prepare<arch>(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf);
175-
}
176-
177-
if (htx_data)
178-
{
179-
ptrdiff_t node = 0;
180-
progress(node, (KeyType)header.bos_id);
181-
unk_ll = getLL(node, (KeyType)header.unk_id);
182-
bos_node_idx = 0;
183-
progress(bos_node_idx, htx_data[(KeyType)header.bos_id]);
184-
}
185-
else
186-
{
187-
unk_ll = getLL(0, (KeyType)header.unk_id);
188-
bos_node_idx = 0;
189-
progress(bos_node_idx, (KeyType)header.bos_id);
190-
}
191-
192-
Deque<MyNode*> dq;
193-
for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front())
194-
{
195-
auto p = dq.front();
196-
for (size_t i = 0; i < p->num_nexts; ++i)
197-
{
198-
auto k = key_data[p->next_offset + i];
199-
auto v = value_data[p->next_offset + i];
200-
if (v <= 0) continue;
201-
auto* child = &p[v];
202-
child->lower = findLowerNode(p, k) - child;
203-
dq.emplace_back(child);
204-
}
205-
}
206-
}
207-
20810
template<ArchType arch, class KeyType, bool transposed, class DiffType>
20911
float KnLangModel<arch, KeyType, transposed, DiffType>::getLL(ptrdiff_t node_idx, KeyType next) const
21012
{

0 commit comments

Comments
 (0)