@@ -7,204 +7,6 @@ namespace kiwi
7
7
{
8
8
namespace lm
9
9
{
10
- template <ArchType arch, class KeyType , bool transposed, class DiffType >
11
- template <ptrdiff_t ...idx>
12
- void KnLangModel<arch, KeyType, transposed, DiffType>::dequantizeDispatch(
13
- tp::seq<idx...>,
14
- size_t bits,
15
- Vector<float >& restored_floats, Vector<float >& restored_leaf_ll,
16
- const char * llq_data, size_t llq_size,
17
- const char * gammaq_data, size_t gammaq_size,
18
- const float * ll_table,
19
- const float * gamma_table,
20
- size_t num_non_leaf_nodes,
21
- size_t num_leaf_nodes
22
- )
23
- {
24
- using Fn = void (*)(Vector<float >&, Vector<float >&,
25
- const char *, size_t ,
26
- const char *, size_t ,
27
- const float *,
28
- const float *,
29
- size_t ,
30
- size_t );
31
- static constexpr Fn table[] = {
32
- &dequantize<idx + 1 >...
33
- };
34
- return table[bits - 1 ](restored_floats, restored_leaf_ll,
35
- llq_data, llq_size,
36
- gammaq_data, gammaq_size,
37
- ll_table, gamma_table,
38
- num_non_leaf_nodes, num_leaf_nodes
39
- );
40
- }
41
-
42
- template <ArchType arch, class KeyType , bool transposed, class DiffType >
43
- KnLangModel<arch, KeyType, transposed, DiffType>::KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move (mem) }
44
- {
45
- auto * ptr = reinterpret_cast <const char *>(base.get ());
46
- auto & header = getHeader ();
47
- const size_t quantized = header.quantized & 0x1F ;
48
- const bool compressed = header.quantized & 0x80 ;
49
-
50
- Vector<KeyType> d_node_size;
51
- auto * node_sizes = reinterpret_cast <const KeyType*>(ptr + header.node_offset );
52
- key_data = make_unique<KeyType[]>((header.ll_offset - header.key_offset ) / sizeof (KeyType));
53
- std::memcpy (&key_data[0 ], ptr + header.key_offset , header.ll_offset - header.key_offset );
54
- size_t num_leaf_nodes = 0 ;
55
- if (compressed)
56
- {
57
- d_node_size.resize (header.num_nodes );
58
- auto qc_header = reinterpret_cast <const uint8_t *>(ptr + header.node_offset );
59
- auto qc_body = reinterpret_cast <const size_t *>(qc_header + (header.num_nodes + 3 ) / 4 );
60
- QCode::template decode<8 >((uint16_t *)d_node_size.data (), qc_header, qc_body, 0 , header.num_nodes );
61
- node_sizes = d_node_size.data ();
62
- }
63
-
64
- for (size_t i = 0 ; i < header.num_nodes ; ++i)
65
- {
66
- if (node_sizes[i]) num_non_leaf_nodes++;
67
- else num_leaf_nodes++;
68
- }
69
-
70
- // restore ll & gamma data
71
- Vector<float > restored_leaf_ll, restored_floats;
72
- const float * ll_data = nullptr ;
73
- const float * gamma_data = nullptr ;
74
- const float * leaf_ll_data = nullptr ;
75
- if (quantized)
76
- {
77
- if (quantized > 16 )
78
- {
79
- throw std::runtime_error{ " 16+ bits quantization not supported." };
80
- }
81
-
82
- restored_floats.resize (num_non_leaf_nodes * 2 );
83
- restored_leaf_ll.resize (num_leaf_nodes);
84
- leaf_ll_data = restored_leaf_ll.data ();
85
- ll_data = &restored_floats[0 ];
86
- gamma_data = &restored_floats[num_non_leaf_nodes];
87
-
88
- const float * ll_table = reinterpret_cast <const float *>(ptr + header.qtable_offset );
89
- const float * gamma_table = ll_table + ((size_t )1 << quantized);
90
-
91
- dequantizeDispatch (tp::gen_seq<16 >{}, quantized, restored_floats, restored_leaf_ll,
92
- ptr + header.ll_offset , header.gamma_offset - header.ll_offset ,
93
- ptr + header.gamma_offset , header.qtable_offset - header.gamma_offset ,
94
- ll_table,
95
- gamma_table,
96
- num_non_leaf_nodes,
97
- num_leaf_nodes
98
- );
99
- extra_buf = toAlignedPtr (gamma_table + ((size_t )1 << quantized));
100
- }
101
- else
102
- {
103
- ll_data = reinterpret_cast <const float *>(ptr + header.ll_offset );
104
- gamma_data = reinterpret_cast <const float *>(ptr + header.gamma_offset );
105
- leaf_ll_data = ll_data + num_non_leaf_nodes;
106
- extra_buf = toAlignedPtr (gamma_data + num_non_leaf_nodes);
107
- }
108
-
109
- size_t htx_vocab_size = header.vocab_size ;
110
- if (header.htx_offset )
111
- {
112
- htx_data = reinterpret_cast <const KeyType*>(ptr + header.htx_offset );
113
- htx_vocab_size = *std::max_element (htx_data, htx_data + header.vocab_size ) + 1 ;
114
- extra_buf = toAlignedPtr (htx_data + header.vocab_size );
115
- }
116
-
117
- if (!header.extra_buf_size )
118
- {
119
- extra_buf = nullptr ;
120
- }
121
-
122
- // restore node's data
123
- node_data = make_unique<MyNode[]>(num_non_leaf_nodes);
124
- all_value_data = make_unique<DiffType[]>(header.num_nodes - 1 + htx_vocab_size);
125
- value_data = &all_value_data[htx_vocab_size];
126
- std::fill (&all_value_data[0 ], value_data, 0 );
127
-
128
- size_t non_leaf_idx = 0 , leaf_idx = 0 , next_offset = 0 ;
129
- Vector<std::array<size_t , 3 >> key_ranges;
130
- for (size_t i = 0 ; i < header.num_nodes ; ++i)
131
- {
132
- if (node_sizes[i])
133
- {
134
- auto & node = node_data[non_leaf_idx];
135
- if (!key_ranges.empty ())
136
- {
137
- auto & back = key_ranges.back ();
138
- value_data[back[1 ]] = non_leaf_idx - back[0 ];
139
- }
140
- node.num_nexts = node_sizes[i];
141
- node.next_offset = next_offset;
142
- node.ll = ll_data[non_leaf_idx];
143
- node.gamma = gamma_data[non_leaf_idx];
144
- next_offset += node_sizes[i];
145
- key_ranges.emplace_back (std::array<size_t , 3 >{ non_leaf_idx, (size_t )node.next_offset , (size_t )(node.next_offset + node.num_nexts ) });
146
- non_leaf_idx++;
147
- }
148
- else
149
- {
150
- auto & back = key_ranges.back ();
151
- reinterpret_cast <float &>(value_data[back[1 ]]) = leaf_ll_data[leaf_idx];
152
- back[1 ]++;
153
- while (key_ranges.back ()[1 ] == key_ranges.back ()[2 ])
154
- {
155
- key_ranges.pop_back ();
156
- if (key_ranges.empty ()) break ;
157
- key_ranges.back ()[1 ]++;
158
- }
159
- leaf_idx++;
160
- }
161
- }
162
-
163
- for (size_t i = 0 ; i < node_data[0 ].num_nexts ; ++i)
164
- {
165
- auto k = key_data[i];
166
- auto v = value_data[i];
167
- all_value_data[k] = v;
168
- }
169
-
170
- Vector<uint8_t > tempBuf;
171
- for (size_t i = 0 ; i < non_leaf_idx; ++i)
172
- {
173
- auto & node = node_data[i];
174
- nst::prepare<arch>(&key_data[node.next_offset ], &value_data[node.next_offset ], node.num_nexts , tempBuf);
175
- }
176
-
177
- if (htx_data)
178
- {
179
- ptrdiff_t node = 0 ;
180
- progress (node, (KeyType)header.bos_id );
181
- unk_ll = getLL (node, (KeyType)header.unk_id );
182
- bos_node_idx = 0 ;
183
- progress (bos_node_idx, htx_data[(KeyType)header.bos_id ]);
184
- }
185
- else
186
- {
187
- unk_ll = getLL (0 , (KeyType)header.unk_id );
188
- bos_node_idx = 0 ;
189
- progress (bos_node_idx, (KeyType)header.bos_id );
190
- }
191
-
192
- Deque<MyNode*> dq;
193
- for (dq.emplace_back (&node_data[0 ]); !dq.empty (); dq.pop_front ())
194
- {
195
- auto p = dq.front ();
196
- for (size_t i = 0 ; i < p->num_nexts ; ++i)
197
- {
198
- auto k = key_data[p->next_offset + i];
199
- auto v = value_data[p->next_offset + i];
200
- if (v <= 0 ) continue ;
201
- auto * child = &p[v];
202
- child->lower = findLowerNode (p, k) - child;
203
- dq.emplace_back (child);
204
- }
205
- }
206
- }
207
-
208
10
template <ArchType arch, class KeyType , bool transposed, class DiffType >
209
11
float KnLangModel<arch, KeyType, transposed, DiffType>::getLL(ptrdiff_t node_idx, KeyType next) const
210
12
{
0 commit comments