Skip to content

Commit 2e56b16

Browse files
committed
Move ll and gamma into Node structure
1 parent 63efcad commit 2e56b16

File tree

2 files changed

+19
-36
lines changed

2 files changed

+19
-36
lines changed

include/kiwi/Knlm.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace kiwi
3131
KeyType num_nexts = 0;
3232
DiffType lower = 0;
3333
uint32_t next_offset = 0;
34+
float ll = 0, gamma = 0;
3435
};
3536

3637
class KnLangModelBase
@@ -56,9 +57,6 @@ namespace kiwi
5657
virtual ptrdiff_t getLowerNode(ptrdiff_t node_idx) const = 0;
5758

5859
virtual size_t nonLeafNodeSize() const = 0;
59-
virtual size_t llSize() const = 0;
60-
virtual const float* getLLBuf() const = 0;
61-
virtual const float* getGammaBuf() const = 0;
6260
virtual const void* getExtraBuf() const = 0;
6361

6462
static std::unique_ptr<KnLangModelBase> create(utils::MemoryObject&& mem, ArchType archType = ArchType::none);

src/Knlm.hpp

+18-33
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ namespace kiwi
1919
{
2020
static constexpr size_t serialAlignment = 16;
2121

22-
2322
using QCode = qe::QCode<0, 2, 8, 16>;
2423

2524
template<size_t bits>
@@ -98,11 +97,8 @@ namespace kiwi
9897
std::unique_ptr<DiffType[]> all_value_data;
9998
size_t num_non_leaf_nodes = 0;
10099
DiffType* value_data = nullptr;
101-
const float* ll_data = nullptr;
102-
const float* gamma_data = nullptr;
103100
const KeyType* htx_data = nullptr;
104101
const void* extra_buf = nullptr;
105-
Vector<float> restored_floats;
106102
float unk_ll = 0;
107103
ptrdiff_t bos_node_idx = 0;
108104

@@ -193,7 +189,9 @@ namespace kiwi
193189
}
194190

195191
// restore ll & gamma data
196-
Vector<float> restored_leaf_ll;
192+
Vector<float> restored_leaf_ll, restored_floats;
193+
const float* ll_data = nullptr;
194+
const float* gamma_data = nullptr;
197195
const float* leaf_ll_data = nullptr;
198196
if (quantized)
199197
{
@@ -262,6 +260,8 @@ namespace kiwi
262260
}
263261
node.num_nexts = node_sizes[i];
264262
node.next_offset = next_offset;
263+
node.ll = ll_data[non_leaf_idx];
264+
node.gamma = gamma_data[non_leaf_idx];
265265
next_offset += node_sizes[i];
266266
key_ranges.emplace_back(std::array<size_t, 3>{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) });
267267
non_leaf_idx++;
@@ -343,14 +343,14 @@ namespace kiwi
343343
node->num_nexts, next, v
344344
))
345345
{
346-
return gamma_data[node_idx] + getLL(node_idx + node->lower, next);
346+
return node->gamma + getLL(node_idx + node->lower, next);
347347
}
348348
}
349349

350350
// non-leaf node
351351
if (v > 0)
352352
{
353-
return ll_data[node_idx + v];
353+
return node_data[node_idx + v].ll;
354354
}
355355
// leaf node
356356
else
@@ -396,7 +396,7 @@ namespace kiwi
396396
node->num_nexts, next, v
397397
))
398398
{
399-
acc += gamma_data[node_idx];
399+
acc += node->gamma;
400400
node_idx += node->lower;
401401
PREFETCH_T0(&key_data[node_data[node_idx].next_offset]);
402402
continue;
@@ -407,7 +407,7 @@ namespace kiwi
407407
if (v > 0)
408408
{
409409
node_idx += v;
410-
return acc + ll_data[node_idx];
410+
return acc + node_data[node_idx].ll;
411411
}
412412
// leaf node
413413
else
@@ -456,16 +456,6 @@ namespace kiwi
456456
return bos_node_idx;
457457
}
458458

459-
const float* getLLBuf() const final
460-
{
461-
return ll_data;
462-
}
463-
464-
const float* getGammaBuf() const final
465-
{
466-
return gamma_data;
467-
}
468-
469459
const void* getExtraBuf() const final
470460
{
471461
return extra_buf;
@@ -481,11 +471,6 @@ namespace kiwi
481471
return num_non_leaf_nodes;
482472
}
483473

484-
size_t llSize() const final
485-
{
486-
return gamma_data - ll_data;
487-
}
488-
489474
std::vector<float> allNextLL(ptrdiff_t node_idx) const final
490475
{
491476
std::vector<float> ret(getHeader().vocab_size, -INFINITY);
@@ -500,14 +485,14 @@ namespace kiwi
500485
}
501486
else
502487
{
503-
ret[keys[i]] = ll_data[node_idx + values[i]];
488+
ret[keys[i]] = node_data[node_idx + values[i]].ll;
504489
}
505490
}
506491

507492
float acc = 0;
508493
while (node->lower)
509494
{
510-
acc += gamma_data[node - &node_data[0]];
495+
acc += node->gamma;
511496
node += node->lower;
512497
keys = &key_data[node->next_offset];
513498
values = &value_data[node->next_offset];
@@ -520,7 +505,7 @@ namespace kiwi
520505
}
521506
else
522507
{
523-
ret[keys[i]] = acc + ll_data[node - &node_data[0] + values[i]];
508+
ret[keys[i]] = acc + node[values[i]].ll;
524509
}
525510
}
526511
}
@@ -550,7 +535,7 @@ namespace kiwi
550535
}
551536
else
552537
{
553-
ret[k] = ll_data[node_idx + v];
538+
ret[k] = node_data[node_idx + v].ll;
554539
}
555540

556541
if (htx_data)
@@ -590,7 +575,7 @@ namespace kiwi
590575
float acc = 0;
591576
while (node->lower)
592577
{
593-
acc += gamma_data[node - &node_data[0]];
578+
acc += node->gamma;
594579
node += node->lower;
595580
keys = &key_data[node->next_offset];
596581
values = &value_data[node->next_offset];
@@ -605,7 +590,7 @@ namespace kiwi
605590
}
606591
else
607592
{
608-
ret[k] = acc + ll_data[node - &node_data[0] + v];
593+
ret[k] = acc + node[v].ll;
609594
}
610595

611596
if (htx_data)
@@ -667,15 +652,15 @@ namespace kiwi
667652
}
668653
else
669654
{
670-
buf.emplace_back(ll_data[node_idx + values[i]], (KeyOut)keys[i]);
655+
buf.emplace_back(node_data[node_idx + values[i]].ll, (KeyOut)keys[i]);
671656
}
672657
}
673658
std::make_heap(buf.begin(), buf.end());
674659

675660
float acc = 0;
676661
while (node->num_nexts < top_n && node->lower)
677662
{
678-
acc += gamma_data[node - &node_data[0]];
663+
acc += node->gamma;
679664
node += node->lower;
680665
keys = &key_data[node->next_offset];
681666
values = &value_data[node->next_offset];
@@ -687,7 +672,7 @@ namespace kiwi
687672
}
688673
else
689674
{
690-
buf.emplace_back(acc + ll_data[node - &node_data[0] + values[i]], (KeyOut)keys[i]);
675+
buf.emplace_back(acc + node[values[i]].ll, (KeyOut)keys[i]);
691676
}
692677
std::push_heap(buf.begin(), buf.end());
693678
}

0 commit comments

Comments
 (0)