@@ -19,7 +19,6 @@ namespace kiwi
19
19
{
20
20
static constexpr size_t serialAlignment = 16 ;
21
21
22
-
23
22
using QCode = qe::QCode<0 , 2 , 8 , 16 >;
24
23
25
24
template <size_t bits>
@@ -98,11 +97,8 @@ namespace kiwi
98
97
std::unique_ptr<DiffType[]> all_value_data;
99
98
size_t num_non_leaf_nodes = 0 ;
100
99
DiffType* value_data = nullptr ;
101
- const float * ll_data = nullptr ;
102
- const float * gamma_data = nullptr ;
103
100
const KeyType* htx_data = nullptr ;
104
101
const void * extra_buf = nullptr ;
105
- Vector<float > restored_floats;
106
102
float unk_ll = 0 ;
107
103
ptrdiff_t bos_node_idx = 0 ;
108
104
@@ -193,7 +189,9 @@ namespace kiwi
193
189
}
194
190
195
191
// restore ll & gamma data
196
- Vector<float > restored_leaf_ll;
192
+ Vector<float > restored_leaf_ll, restored_floats;
193
+ const float * ll_data = nullptr ;
194
+ const float * gamma_data = nullptr ;
197
195
const float * leaf_ll_data = nullptr ;
198
196
if (quantized)
199
197
{
@@ -262,6 +260,8 @@ namespace kiwi
262
260
}
263
261
node.num_nexts = node_sizes[i];
264
262
node.next_offset = next_offset;
263
+ node.ll = ll_data[non_leaf_idx];
264
+ node.gamma = gamma_data[non_leaf_idx];
265
265
next_offset += node_sizes[i];
266
266
key_ranges.emplace_back (std::array<size_t , 3 >{ non_leaf_idx, (size_t )node.next_offset , (size_t )(node.next_offset + node.num_nexts ) });
267
267
non_leaf_idx++;
@@ -343,14 +343,14 @@ namespace kiwi
343
343
node->num_nexts , next, v
344
344
))
345
345
{
346
- return gamma_data[node_idx] + getLL (node_idx + node->lower , next);
346
+ return node-> gamma + getLL (node_idx + node->lower , next);
347
347
}
348
348
}
349
349
350
350
// non-leaf node
351
351
if (v > 0 )
352
352
{
353
- return ll_data [node_idx + v];
353
+ return node_data [node_idx + v]. ll ;
354
354
}
355
355
// leaf node
356
356
else
@@ -396,7 +396,7 @@ namespace kiwi
396
396
node->num_nexts , next, v
397
397
))
398
398
{
399
- acc += gamma_data[node_idx] ;
399
+ acc += node-> gamma ;
400
400
node_idx += node->lower ;
401
401
PREFETCH_T0 (&key_data[node_data[node_idx].next_offset ]);
402
402
continue ;
@@ -407,7 +407,7 @@ namespace kiwi
407
407
if (v > 0 )
408
408
{
409
409
node_idx += v;
410
- return acc + ll_data [node_idx];
410
+ return acc + node_data [node_idx]. ll ;
411
411
}
412
412
// leaf node
413
413
else
@@ -456,16 +456,6 @@ namespace kiwi
456
456
return bos_node_idx;
457
457
}
458
458
459
- const float * getLLBuf () const final
460
- {
461
- return ll_data;
462
- }
463
-
464
- const float * getGammaBuf () const final
465
- {
466
- return gamma_data;
467
- }
468
-
469
459
const void * getExtraBuf () const final
470
460
{
471
461
return extra_buf;
@@ -481,11 +471,6 @@ namespace kiwi
481
471
return num_non_leaf_nodes;
482
472
}
483
473
484
- size_t llSize () const final
485
- {
486
- return gamma_data - ll_data;
487
- }
488
-
489
474
std::vector<float > allNextLL (ptrdiff_t node_idx) const final
490
475
{
491
476
std::vector<float > ret (getHeader ().vocab_size , -INFINITY);
@@ -500,14 +485,14 @@ namespace kiwi
500
485
}
501
486
else
502
487
{
503
- ret[keys[i]] = ll_data [node_idx + values[i]];
488
+ ret[keys[i]] = node_data [node_idx + values[i]]. ll ;
504
489
}
505
490
}
506
491
507
492
float acc = 0 ;
508
493
while (node->lower )
509
494
{
510
- acc += gamma_data[ node - &node_data[ 0 ]] ;
495
+ acc += node-> gamma ;
511
496
node += node->lower ;
512
497
keys = &key_data[node->next_offset ];
513
498
values = &value_data[node->next_offset ];
@@ -520,7 +505,7 @@ namespace kiwi
520
505
}
521
506
else
522
507
{
523
- ret[keys[i]] = acc + ll_data[ node - &node_data[ 0 ] + values[i]];
508
+ ret[keys[i]] = acc + node[ values[i]]. ll ;
524
509
}
525
510
}
526
511
}
@@ -550,7 +535,7 @@ namespace kiwi
550
535
}
551
536
else
552
537
{
553
- ret[k] = ll_data [node_idx + v];
538
+ ret[k] = node_data [node_idx + v]. ll ;
554
539
}
555
540
556
541
if (htx_data)
@@ -590,7 +575,7 @@ namespace kiwi
590
575
float acc = 0 ;
591
576
while (node->lower )
592
577
{
593
- acc += gamma_data[ node - &node_data[ 0 ]] ;
578
+ acc += node-> gamma ;
594
579
node += node->lower ;
595
580
keys = &key_data[node->next_offset ];
596
581
values = &value_data[node->next_offset ];
@@ -605,7 +590,7 @@ namespace kiwi
605
590
}
606
591
else
607
592
{
608
- ret[k] = acc + ll_data[ node - &node_data[ 0 ] + v] ;
593
+ ret[k] = acc + node[v]. ll ;
609
594
}
610
595
611
596
if (htx_data)
@@ -667,15 +652,15 @@ namespace kiwi
667
652
}
668
653
else
669
654
{
670
- buf.emplace_back (ll_data [node_idx + values[i]], (KeyOut)keys[i]);
655
+ buf.emplace_back (node_data [node_idx + values[i]]. ll , (KeyOut)keys[i]);
671
656
}
672
657
}
673
658
std::make_heap (buf.begin (), buf.end ());
674
659
675
660
float acc = 0 ;
676
661
while (node->num_nexts < top_n && node->lower )
677
662
{
678
- acc += gamma_data[ node - &node_data[ 0 ]] ;
663
+ acc += node-> gamma ;
679
664
node += node->lower ;
680
665
keys = &key_data[node->next_offset ];
681
666
values = &value_data[node->next_offset ];
@@ -687,7 +672,7 @@ namespace kiwi
687
672
}
688
673
else
689
674
{
690
- buf.emplace_back (acc + ll_data[ node - &node_data[ 0 ] + values[i]], (KeyOut)keys[i]);
675
+ buf.emplace_back (acc + node[ values[i]]. ll , (KeyOut)keys[i]);
691
676
}
692
677
std::push_heap (buf.begin (), buf.end ());
693
678
}
0 commit comments