-
Notifications
You must be signed in to change notification settings - Fork 740
Expand file tree
/
Copy pathbeam_search_softmax.cu
More file actions
1563 lines (1386 loc) · 54.8 KB
/
beam_search_softmax.cu
File metadata and controls
1563 lines (1386 loc) · 54.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/types.h>
#if !defined(_WIN32)
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#endif
#include <algorithm>
#include "helper.h"
#include "stdint.h"
#include "cccl_compat.h" // CCCL 3.0 compatibility
// Magnitude used throughout this file for the "minus infinity" sentinel
// (kernels initialise with -FLT_MAX / -MAX_T_VAL).
// NOTE(review): this redefines the standard <cfloat> FLT_MAX (~3.4e38) with a
// smaller value; if <cfloat> is pulled in first (e.g. via CUB) this is a macro
// redefinition -- confirm this is intended.
#define FLT_MAX 1e38
// NOTE(review): not used in the visible part of the file; presumably the block
// size of a launcher defined further down -- confirm.
static constexpr int kBlockSizeForSmallBeamWidth = 256;
// Upper bound on the number of vocabulary parts for the fast stage-1 path
// (the stage-2 fast launcher assumes voc_parts never exceeds this).
static constexpr int kMaxVocabPartForStage1FastKernel = 128;
// Expands to `case K:` that launches invokeTopKSoftMaxLauncher with 2 * K
// candidates. Relies on `T`, `GROUP`, `params`, `beam_group_idx` and `stream`
// being in scope where the macro is expanded.
#define CASE_K(K) \
case K: \
invokeTopKSoftMaxLauncher<T, 2 * K, GROUP>( \
params, beam_group_idx, stream); \
break
// Expands to `case K:` that calls ComputeVocParts with 2 * K candidates.
// Relies on `T` and `params` being in scope where the macro is expanded.
#define DISPATCH_COMPUTE_PARTS_K(K) \
case K: \
ComputeVocParts<T, 2 * K>(params); \
break
// Bundle of scalar settings, input/output device pointers and workspace
// pointers shared by the beam-search kernels in this file. It is passed to
// kernels by value, so every pointer must refer to device memory.
template <typename T>
struct BeamSearchParams {
// Scalar values
int batch_size{0};
int beam_width{0};
int beam_group_size{0};
int beam_group_idx{0};
int vocab_size{0};
int dec_stride{0};
int max_seq_len{0};
int end_ids_len{0};
// Selects the fused softmax + top-K path in the stage-1/stage-2 kernels.
bool fuse_softmax{true};
// Read by batch_topk.
bool early_stop{false};
// Number of vocabulary chunks each beam row is split into for stage 1.
int voc_parts{0};
bool use_fast_kernel{true};
// Shared-memory budget used to decide whether the fast stage-2 kernel fits.
int max_smem_per_block{0};
// Logits, one row of vocab_size entries per beam (indexed as
// logits + beam_batch_id * vocab_size by the kernels).
T *logits{nullptr};
const int *step_ids{nullptr}; // [BS * BM, 1]
const int *seq_lens{nullptr}; // [BS * BM, 1]
const int *max_dec_lens{nullptr};
const int *end_ids{nullptr};
const T *cum_scores{nullptr};
const int *block_tables{nullptr};
const int *beam_cache_ids{nullptr};
const float *length_penalty{nullptr}; // [BS, 1]
const float *diversity_penalty{nullptr}; // [BS, 1]
bool *stop_flags{nullptr}; // [BS, 1]
int *cache_ids_out{nullptr}; // [BS * BM, max_dec_len]
bool *beam_finished{nullptr}; // [BS * BM, 1]
int *block_tables_out{nullptr}; // [BS * BM, max_seq_len]
T *cum_scores_out{nullptr}; // [BS * BM, 1]
int *beam_hyps_out{nullptr}; // [BS * BM, max_dec_len]
T *beam_hyps_score_out{nullptr}; // [BS * BM, 1]
// func out
int *next_tokens{nullptr};
int *parent_ids{nullptr};
// workspace
// Stage-2 output: K candidate token ids / scores per (batch, group beam).
int *tmp_ids{nullptr};
T *tmp_vals{nullptr};
// Stage-1 output: voc_parts packed records of (2 * K + 2) floats per
// (batch, group beam); read as float by the stage-2 kernels.
T *tmp_buffer{nullptr};
};
// Integer division that rounds the quotient up, restricted to integral
// operand types. The result type is the usual arithmetic-conversion type of
// the two operands.
template <typename T,
          typename U,
          typename = std::enable_if_t<std::is_integral<T>::value>,
          typename = std::enable_if_t<std::is_integral<U>::value>>
auto constexpr ceilDiv(T numerator, U denominator) {
  // Padding the numerator by (denominator - 1) turns truncation into ceiling.
  const auto padded = numerator + denominator - 1;
  return padded / denominator;
}
// Returns true when `id` equals any of the first `length` entries of end_ids.
__device__ bool is_in_end(const int id, const int *end_ids, int length) {
  for (int idx = 0; idx < length; ++idx) {
    if (end_ids[idx] == id) {
      return true;
    }
  }
  return false;
}
// Length-normalised score: log_prob / length^length_penalty.
// The division is skipped when the penalty is exactly 0 or length == 1.
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob,
                                                  int length,
                                                  float length_penalty) {
  const bool no_penalty = (length_penalty == 0.0f) || (length == 1);
  if (no_penalty) {
    return log_prob;
  }
  const T divisor = static_cast<T>(powf(length, length_penalty));
  return log_prob / divisor;
}
// <<<batch_size, beam_group_size>>>
// Diverse beam search: for the beam group `beam_group_idx`, each thread takes
// one beam of that group and subtracts diversity_penalty[batch] from its logit
// of every token chosen (next_tokens) by an unfinished beam of an earlier
// group. A token chosen by several earlier beams is penalised once per beam.
template <typename T, int K>
__global__ void apply_group_diversity_penalty(BeamSearchParams<T> params,
                                              const int batch_size,
                                              const int beam_width,
                                              const int beam_group_idx,
                                              const int vocab_size) {
  const int group_size = K / 2;
  const int batch = blockIdx.x;
  const int member = threadIdx.x;

  // Logit row of beam (beam_group_idx * group_size + member) of this batch.
  T *row_logits = params.logits;
  row_logits += batch * beam_width * vocab_size;
  row_logits += beam_group_idx * group_size * vocab_size;
  row_logits += member * vocab_size;

  const bool *finished = params.beam_finished + batch * beam_width;
  const int *prev_tokens = params.next_tokens + batch * beam_width;
  // All beams belonging to groups before the current one.
  const int num_prev_beams = beam_group_idx * group_size;
#pragma unroll
  for (int prev = 0; prev < num_prev_beams; ++prev) {
    if (finished[prev]) {
      continue;
    }
    row_logits[prev_tokens[prev]] -= params.diversity_penalty[batch];
  }
}
// Online-softmax accumulator: `logit` is the largest logit folded in so far and
// `score` is the sum of exp(x - logit) over the folded elements (see
// reduce_softmax_op).
struct DySoftMaxStruct {
float logit;
float score;
};
// Merges two online-softmax accumulators: keeps the larger running max and
// rescales the other side's sum onto that max before adding.
__device__ __forceinline__ DySoftMaxStruct
reduce_softmax_op(DySoftMaxStruct a, DySoftMaxStruct b) {
  const bool a_wins = a.logit > b.logit;
  const DySoftMaxStruct &hi = a_wins ? a : b;
  const DySoftMaxStruct &lo = a_wins ? b : a;
  DySoftMaxStruct merged;
  merged.score = hi.score + lo.score * expf(lo.logit - hi.logit);
  merged.logit = hi.logit;
  return merged;
}
// One finished hypothesis: its score plus a pointer to the token buffer it
// writes into.
template <typename T>
struct BeamHypothesis {
  T score;
  int *seq;
  int seq_len;
  // Attach the hypothesis to `_seq`, seed its score, and record the buffer
  // capacity `_max_seq_len` in seq_len.
  __device__ __forceinline__ void init(int *_seq,
                                       T _score,
                                       const int _max_seq_len) {
    seq_len = _max_seq_len;
    seq = _seq;
    score = _score;
  }
};
// Fixed-size list of the K best finished hypotheses, kept sorted by score in
// descending order (hyps[K - 1] is the worst). Each hypothesis owns a row of
// max_dec_len tokens in an external buffer.
template <typename T, int K>
struct BeamHypothesesTopK {
  BeamHypothesis<T> hyps[K];
  int max_dec_len;

  // Bind hypothesis i to row i of `_beam_hyps` (stride `_max_dec_len`) and
  // seed its score from `_beam_hyps_score[i]`.
  __device__ __forceinline__ void init(int *_beam_hyps,
                                       T *_beam_hyps_score,
                                       const int _max_dec_len) {
    max_dec_len = _max_dec_len;
    for (int i = 0; i < K; i++) {
      hyps[i].init(
          _beam_hyps + i * _max_dec_len, _beam_hyps_score[i], _max_dec_len);
    }
  }

  // Offer the sequence token_ids[0, step) + cur_token_id with `score`.
  // If it beats the current worst, it overwrites the worst slot and is bubbled
  // up into sorted position. Scores and token prefixes are swapped together.
  __device__ void insert(const int *token_ids,
                         int step,
                         int cur_token_id,
                         T score) {
    if (score > get_worst_score()) {
      for (int i = 0; i < step; i++) {
        hyps[K - 1].seq[i] = token_ids[i];
      }
      hyps[K - 1].seq[step] = cur_token_id;
      hyps[K - 1].score = score;
      for (int k = K - 2; k >= 0; --k) {
        if (hyps[k + 1].score > hyps[k].score) {
          T tmp_score = hyps[k].score;
          hyps[k].score = hyps[k + 1].score;
          hyps[k + 1].score = tmp_score;
          // Swap the whole prefix [0, step]. Stopping at the first position
          // where both rows hold an id <= 0 truncated the swap whenever a
          // shared prefix contained the (valid) token id 0, leaving scores
          // paired with the wrong sequences.
          for (int i = 0; i <= step; i++) {
            const int tmp_val = hyps[k + 1].seq[i];
            hyps[k + 1].seq[i] = hyps[k].seq[i];
            hyps[k].seq[i] = tmp_val;
          }
        }
      }
    }
  }

  // Score of the lowest-ranked kept hypothesis.
  __device__ __forceinline__ T get_worst_score() { return hyps[K - 1].score; }
};
// Register-resident top-K list sorted by value (descending), ties broken by
// ascending id. An id of -1 marks an empty slot.
template <typename T, int K>
struct TopK {
  int ids[K];
  T vals[K];
  int parent_ids[K];

  // Offer (elem, elem_id). The candidate takes the last slot if it is larger,
  // if that slot is empty, or if it ties with a smaller id. A single upward
  // bubble pass then restores order; an empty upper slot is always swapped.
  __device__ __forceinline__ void insert(T elem, int elem_id) {
    constexpr int kLast = K - 1;
    const bool takes_last_slot =
        elem > vals[kLast] || ids[kLast] == -1 ||
        (elem == vals[kLast] && elem_id < ids[kLast]);
    if (takes_last_slot) {
      vals[kLast] = elem;
      ids[kLast] = elem_id;
    }
    for (int upper = K - 2; upper >= 0; --upper) {
      const int lower = upper + 1;
      const bool misordered =
          vals[lower] > vals[upper] || ids[upper] == -1 ||
          (vals[lower] == vals[upper] && ids[lower] < ids[upper]);
      if (!misordered) {
        continue;
      }
      const T held_val = vals[upper];
      const int held_id = ids[upper];
      vals[upper] = vals[lower];
      ids[upper] = ids[lower];
      vals[lower] = held_val;
      ids[lower] = held_id;
    }
  }

  // Same as insert(elem, elem_id) but also carries parent_id with the entry.
  __device__ __forceinline__ void insert(T elem, int elem_id, int parent_id) {
    constexpr int kLast = K - 1;
    const bool takes_last_slot =
        elem > vals[kLast] || ids[kLast] == -1 ||
        (elem == vals[kLast] && elem_id < ids[kLast]);
    if (takes_last_slot) {
      vals[kLast] = elem;
      ids[kLast] = elem_id;
      parent_ids[kLast] = parent_id;
    }
    for (int upper = K - 2; upper >= 0; --upper) {
      const int lower = upper + 1;
      const bool misordered =
          vals[lower] > vals[upper] || ids[upper] == -1 ||
          (vals[lower] == vals[upper] && ids[lower] < ids[upper]);
      if (!misordered) {
        continue;
      }
      const T held_val = vals[upper];
      const int held_id = ids[upper];
      const int held_parent = parent_ids[upper];
      vals[upper] = vals[lower];
      ids[upper] = ids[lower];
      parent_ids[upper] = parent_ids[lower];
      vals[lower] = held_val;
      ids[lower] = held_id;
      parent_ids[lower] = held_parent;
    }
  }
};
// Merges two top-K lists by inserting every entry of `b` into a copy of `a`.
// parent_ids are not carried over by this merge.
template <typename T, int K>
__device__ __forceinline__ TopK<T, K> reduce_topk_op(const TopK<T, K> &a,
                                                     const TopK<T, K> &b) {
  TopK<T, K> merged = a;
  int slot = 0;
  while (slot < K) {
    merged.insert(b.vals[slot], b.ids[slot]);
    ++slot;
  }
  return merged;
}
// Per-thread state of the fused path: online-softmax statistics plus a top-K
// candidate list, merged pairwise by reduce_topk_softmax_op.
template <typename T, int K>
struct TopKSoftMax {
DySoftMaxStruct softmax_md;
TopK<T, K> topk;
};
// Reduction operator for the fused path: merges the top-K lists and the
// online-softmax statistics of two partial results independently.
template <typename T, int K>
__device__ __forceinline__ TopKSoftMax<T, K> reduce_topk_softmax_op(
    const TopKSoftMax<T, K> &a, const TopKSoftMax<T, K> &b) {
  TopKSoftMax<T, K> combined;
  combined.topk = reduce_topk_op(a.topk, b.topk);
  // Running max logit and rescaled exp-sum across both sides.
  combined.softmax_md = reduce_softmax_op(a.softmax_md, b.softmax_md);
  return combined;
}
// Online-softmax pair used by the fast kernels: `m` is the running max logit
// and `d` the running sum of exp(x - m) (see reduce_md_op).
// The 8-byte alignment presumably lets the pair move as one 64-bit word.
struct __align__(8) MD {
float m;
float d;
};
// Combines two (max, exp-sum) pairs: the side with the strictly larger max
// (b on ties) keeps its max, and the other side's sum is rescaled onto it.
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
  if (a.m > b.m) {
    return MD{a.m, a.d + b.d * __expf(b.m - a.m)};
  }
  return MD{b.m, b.d + a.d * __expf(a.m - b.m)};
}
// Stage 1 (fast path) of the two-stage softmax + top-K.
// Grid: blockIdx.x selects a (batch, group beam) row, blockIdx.y a vocabulary
// chunk of vocab_chunk_size entries. The block stages the chunk's logits in
// dynamic shared memory and extracts the chunk's top K logits by K rounds of
// block-wide argmax. It also folds every logit of the chunk into an
// online-softmax (max, exp-sum) pair. The packed record
// [K vocab ids (int bits) | K values | exp-sum | max] (2 * K + 2 floats) is
// written to tmp_buffer.
// NOTE(review): assumes the launch provides at least
// vocab_chunk_size * sizeof(T) bytes of dynamic shared memory -- confirm.
template <typename T, int K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
    void beam_search_softmax_topk_stage1_fast(const T *logits,
                                              float *tmp_buffer,
                                              const int *end_ids,
                                              const bool *beam_finished,
                                              const int *seq_lens,
                                              int beam_width,
                                              int beam_group_idx,
                                              int vocab_size,
                                              int vocab_chunk_size) {
  constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
  const int beam_group_size = K / 2;
  const int tid = threadIdx.x;
  const int group_beam_batch_id = blockIdx.x;
  const int batch_id = group_beam_batch_id / beam_group_size;
  const int beam_group_sub_id = group_beam_batch_id % beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_id;
  const int seq_len = seq_lens[beam_batch_id];
  const bool finished = beam_finished[beam_batch_id];
  // Whole block exits together (values depend only on blockIdx).
  if (seq_len < 0 || finished) {
    return;
  }
  const int section_start = vocab_chunk_size * blockIdx.y;
  const int section_end =
      std::min(section_start + vocab_chunk_size, vocab_size);
  const int valid_smem_length = section_end - section_start;
  T const MAX_T_VAL = 1e38;
  // Load the chunk into smemLogProbs while accumulating the per-thread
  // online-softmax pair and argmax. Thread tid handles chunk offsets
  // tid, tid + THREADBLOCK_SIZE, ...
  extern __shared__ char smem[];
  T *smemLogProbs = reinterpret_cast<T *>(smem);
  MD partial_md{-MAX_T_VAL, 0.0f};
  using KVPair = cub::KeyValuePair<int, T>;
  // Sentinel: key vocab_size - 1 is NOT a valid chunk offset in general.
  KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
  fd_cub_compat::ArgMax argmax;
  T const *local_logits = logits + beam_batch_id * vocab_size;
#pragma unroll 1
  for (int i = section_start + tid; i < section_end; i += THREADBLOCK_SIZE) {
    T const val = local_logits[i];
    const int smem_index = i - section_start;
    smemLogProbs[smem_index] = val;
    MD new_elem_md{val, 1.0F};
    partial_md = reduce_md_op(partial_md, new_elem_md);
    KVPair new_elem_topk{smem_index, val};
    topKVPairPartial = argmax(topKVPairPartial, new_elem_topk);
  }
  __syncthreads();
  // Pop the top K elements of this chunk into smemOutput, one per round.
  __shared__ float smemOutput[PACKED_TOP_KMD_SIZE];
  __shared__ int threadToUpdate;
  using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
  using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
  __shared__ union {
    typename BlockReduceTopK::TempStorage topk;
    typename BlockReduceMD::TempStorage md;
  } smemReduceBuffer;
  for (int i = 0; i < 2 * beam_group_size; ++i) {
    KVPair topKVPair =
        BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
    if (tid == 0) {
      const int index = section_start + topKVPair.key;
      reinterpret_cast<int *>(smemOutput)[i] = index;
      smemOutput[K + i] = topKVPair.value;
      // The reduction returns the sentinel key (vocab_size - 1) when nothing
      // in the chunk beats -MAX_T_VAL (e.g. an all -inf chunk). That key is
      // outside the staged chunk, so only pollute real entries.
      if (topKVPair.key < valid_smem_length) {
        smemLogProbs[topKVPair.key] = -MAX_T_VAL;
      }
      // Offset key was loaded by thread (key % THREADBLOCK_SIZE).
      threadToUpdate = topKVPair.key % THREADBLOCK_SIZE;
    }
    __syncthreads();
    if (tid == threadToUpdate && i < 2 * beam_group_size - 1) {
      // The owner of the popped element rescans its slice; skipped in the
      // last round because no further pop follows.
      topKVPairPartial.key = vocab_size - 1;
      topKVPairPartial.value = -MAX_T_VAL;
      for (int index = tid; index < valid_smem_length;
           index += THREADBLOCK_SIZE) {
        topKVPairPartial =
            argmax(topKVPairPartial, {index, smemLogProbs[index]});
      }
    }
  }
  // Reduce the per-thread online-softmax pairs, which cover EVERY element of
  // the chunk, and store them in the tail of smemOutput.
  auto reduce_md_func = [](const MD &a, const MD &b) {
    return reduce_md_op(a, b);
  };
  MD total_md =
      BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);
  if (tid == 0) {
    smemOutput[2 * K] = total_md.d;
    smemOutput[2 * K + 1] = total_md.m;
  }
  __syncthreads();
  // Record for (row, chunk) lives at
  // (group_beam_batch_id * gridDim.y + blockIdx.y) * PACKED_TOP_KMD_SIZE.
  float *local_temp_buffer =
      tmp_buffer + group_beam_batch_id * PACKED_TOP_KMD_SIZE * gridDim.y +
      blockIdx.y * PACKED_TOP_KMD_SIZE;
#pragma unroll
  for (int i = tid; i < PACKED_TOP_KMD_SIZE; i += THREADBLOCK_SIZE) {
    local_temp_buffer[i] = smemOutput[i];
  }
}
// <<<(batch_size * beam_group_size, voc_parts), 128>>>
// Stage 1 (generic path) of the two-stage top-K: blockIdx.x selects a
// (batch, group beam) row and blockIdx.y a contiguous vocabulary slice of
// ceil(vocab_size / gridDim.y) entries. Each block writes one packed record
// [K ids (int bits) | K values | exp-sum | max logit] to params.tmp_buffer.
// With fuse_softmax == false only the id/value halves are computed.
template <typename T, int K, int THREADBLOCK_SIZE, int PACKED_TOP_KMD_SIZE>
__global__ void beam_search_softmax_topk_stage1(BeamSearchParams<T> params,
const int beam_width,
const int beam_group_idx,
const int vocab_size,
const bool fuse_softmax) {
const int thread_id = threadIdx.x;
const int beam_group_size = K / 2;
const int batch_id = blockIdx.x / beam_group_size;
const int beam_group_sub_idx = blockIdx.x % beam_group_size;
const int beam_batch_id = batch_id * beam_width +
beam_group_idx * beam_group_size +
beam_group_sub_idx;
const bool finish = params.beam_finished[beam_batch_id];
const int seq_len = params.seq_lens[beam_batch_id];
// for dybatch: skip rows with a negative seq_len or an already finished beam
// (the whole block exits together, so later __syncthreads are safe).
if (seq_len < 0 || finish) {
return;
}
// 2 * K + 2
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
const T MAX_T_VAL = FLT_MAX;
// Vocabulary slice [section_start, section_end) handled by this block.
const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local;
section_end = (section_end > vocab_size) ? vocab_size : section_end;
T *logits = params.logits + beam_batch_id * vocab_size;
if (fuse_softmax) {
typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopKSoftMax<T, K> partial;
for (int i = 0; i < K; ++i) {
partial.topk.ids[i] = -1;
partial.topk.vals[i] = -MAX_T_VAL;
}
partial.softmax_md.logit = -MAX_T_VAL;
partial.softmax_md.score = 0.0F;
// Each thread strides over the slice, folding every logit into its
// online-softmax pair and its private top-K list.
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end;
elem_id += THREADBLOCK_SIZE) {
T elem = logits[elem_id];
DySoftMaxStruct new_elem{elem, 1.0F};
partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem);
partial.topk.insert(elem, elem_id);
}
// Block-wide merge of the per-thread partial results for this slice.
TopKSoftMax<T, K> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op<T, K>);
// Only thread 0 holds the valid reduction result.
if (thread_id == 0) {
for (int i = 0; i < K; i++) {
reinterpret_cast<int *>(buf_s)[i] = total.topk.ids[i];
buf_s[K + i] = total.topk.vals[i];
}
buf_s[2 * K] = total.softmax_md.score;
buf_s[2 * K + 1] = total.softmax_md.logit;
}
} else {
typedef cub::BlockReduce<TopK<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopK<T, K> partial;
for (int i = 0; i < K; ++i) {
partial.ids[i] = -1;
partial.vals[i] = -MAX_T_VAL;
}
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end;
elem_id += THREADBLOCK_SIZE) {
T elem = logits[elem_id];
partial.insert(elem, elem_id);
}
TopK<T, K> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, K>);
// NOTE(review): buf_s[2 * K] and buf_s[2 * K + 1] are never written on
// this path but are still copied to tmp_buffer below (uninitialised
// shared memory); the non-fused stage 2 does not read them.
if (thread_id == 0) {
for (int i = 0; i < K; i++) {
reinterpret_cast<int *>(buf_s)[i] = total.ids[i];
buf_s[K + i] = total.vals[i];
}
}
}
__syncthreads();
// Record for (row, slice) lives at
// (blockIdx.x * gridDim.y + blockIdx.y) * PACKED_TOP_KMD_SIZE.
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE;
elem_id += THREADBLOCK_SIZE) {
params.tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y +
blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] =
buf_s[elem_id];
}
}
// Stage 2 (fast path): one block per (batch, group beam) row merges the
// voc_parts packed stage-1 records of that row. It selects the global top K by
// K rounds of block argmax, reduces the online-softmax pairs, and writes
// log-softmax(value) + cum_scores[row] and the vocab id to tmp_vals/tmp_ids.
// With IS_FAST_KERNEL the records are first copied into dynamic shared memory.
// Otherwise they are read and invalidated in place in tmp_buffer.
// Thread t (< voc_parts) owns record t, so THREADBLOCK_SIZE >= voc_parts.
template <typename T, int K, int THREADBLOCK_SIZE, bool IS_FAST_KERNEL>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void beam_search_softmax_topk_stage2_fast(
        int *__restrict tmp_ids,
        T *__restrict tmp_vals,
        float *__restrict tmp_buffer,
        const float *__restrict cum_scores,
        const bool *__restrict beam_finished,
        const int *__restrict seq_lens,
        const int beam_width,
        const int beam_group_idx,
        const int vocab_size,
        const int voc_parts) {
  constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
  constexpr int beam_group_size = K / 2;
  const int group_beam_batch_id = blockIdx.x;
  const int beam_group_sub_id = blockIdx.x % beam_group_size;
  const int batch_id = group_beam_batch_id / beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_id;
  if (seq_lens[beam_batch_id] < 0 || beam_finished[beam_batch_id]) {
    return;
  }
  const int tid = threadIdx.x;
  T const MAX_T_VAL = FLT_MAX;
  // Packed floats belonging to this row. Every real candidate key produced
  // by the argmax below is an index < num_packed.
  const int num_packed = PACKED_TOP_KMD_SIZE * voc_parts;
  using KVPair = cub::KeyValuePair<int, T>;
  using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
  using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
  __shared__ KVPair buf_smem_kv[K];
  __shared__ union {
    typename BlockReduceTopK::TempStorage topk;
    typename BlockReduceMD::TempStorage md;
  } smemReduceBuffer;
  fd_cub_compat::ArgMax argmax;
  MD partial_md{-MAX_T_VAL, 0.0f};
  auto reduce_md_func = [](const MD &a, const MD &b) {
    return reduce_md_op(a, b);
  };
  float *localTempBuffer =
      tmp_buffer + PACKED_TOP_KMD_SIZE * group_beam_batch_id * voc_parts;
  if constexpr (IS_FAST_KERNEL) {  // stage the records in shared memory
    extern __shared__ char smem[];
    float *smemVal = reinterpret_cast<float *>(smem);
    for (int idx = tid; idx < num_packed; idx += THREADBLOCK_SIZE) {
      smemVal[idx] = localTempBuffer[idx];
    }
    localTempBuffer = smemVal;
    __syncthreads();
  }
  // Select the global top K across all voc_parts, one element per round.
  for (int k = 0; k < K; ++k) {
    // Sentinel key vocab_size - 1 is generally NOT a valid buffer index.
    KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
    if (tid < voc_parts) {
      for (int i = 0; i < K; ++i) {
        const int current_index = tid * PACKED_TOP_KMD_SIZE + i;
        T topValue = localTempBuffer[current_index + K];
        topKVPairPartial = argmax(topKVPairPartial, {current_index, topValue});
      }
    }
    KVPair topKVPair =
        BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
    __syncthreads();
    if (tid == 0) {
      const int temp_offset = topKVPair.key;
      // If the sentinel won (no candidate above -MAX_T_VAL), its key is not
      // a buffer offset: reading or invalidating it would touch memory
      // outside this row. Keep the sentinel pair as-is in that case.
      if (temp_offset < num_packed) {
        topKVPair.key = reinterpret_cast<int *>(localTempBuffer)[temp_offset];
        // Invalidate the chosen candidate so later rounds skip it.
        reinterpret_cast<int *>(localTempBuffer)[temp_offset] = vocab_size - 1;
        localTempBuffer[temp_offset + K] = -MAX_T_VAL;
      }
      buf_smem_kv[k] = topKVPair;
    }
    __syncthreads();
  }
  // Reduce the per-part (exp-sum, max) pairs.
  if (tid < voc_parts) {
    partial_md.d = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * K];
    partial_md.m = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * K + 1];
  }
  __syncthreads();  // smemReduceBuffer is reused by a different reduction
  MD total_md =
      BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);
  if (tid == 0) {
    // log-softmax(v) = v - max - log(sum(exp(x - max)))
    float d_total_log = logf(total_md.d);
    for (int i = 0; i < K; ++i) {
      float val =
          static_cast<float>(buf_smem_kv[i].value) - total_md.m - d_total_log;
      tmp_ids[group_beam_batch_id * K + i] = buf_smem_kv[i].key;
      tmp_vals[group_beam_batch_id * K + i] = val + cum_scores[beam_batch_id];
    }
  }
}
// Launches beam_search_softmax_topk_stage2_fast with N_VOCAB_PART threads per
// block and batch_size * beam_group_size blocks, then RETURNS from the
// enclosing function (the trailing `return;` is part of the macro).
// When IS_FAST_KERNEL is true, nShareMemory bytes of dynamic shared memory are
// requested, and kernels needing >= 48 KiB are first opted in via
// cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize).
// Relies on T, K, nShareMemory, params, batch_size, beam_group_size, stream,
// beam_width, beam_group_idx, vocab_size and voc_parts being in scope.
#define BEAM_STAGE2_KERNEL(N_VOCAB_PART, IS_FAST_KERNEL) \
do { \
if (IS_FAST_KERNEL && nShareMemory >= (48 << 10)) { \
cudaFuncSetAttribute( \
beam_search_softmax_topk_stage2_fast<T, \
K, \
N_VOCAB_PART, \
IS_FAST_KERNEL>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
nShareMemory); \
} \
beam_search_softmax_topk_stage2_fast<T, K, N_VOCAB_PART, IS_FAST_KERNEL> \
<<<batch_size * beam_group_size, \
N_VOCAB_PART, \
IS_FAST_KERNEL * nShareMemory, \
stream>>>(params->tmp_ids, \
params->tmp_vals, \
params->tmp_buffer, \
params->cum_scores, \
params->beam_finished, \
params->seq_lens, \
beam_width, \
beam_group_idx, \
vocab_size, \
voc_parts); \
} while (0); \
return;
// Chooses and launches the stage-2 fast kernel. If all voc_parts packed
// records (plus the K-entry KV buffer) fit under max_smem_per_block, the
// shared-memory variant runs with the smallest block size (32/64/128) that
// gives one thread per vocab part. Otherwise the global-memory variant runs
// with 128 threads. Each BEAM_STAGE2_KERNEL expansion returns.
template <typename T, int K>
__inline__ void beamSearchSoftmaxTopkStage2FastKernelLauncher(
    BeamSearchParams<T> *params,
    const int batch_size,
    const int beam_width,
    const int beam_group_idx,
    const int vocab_size,
    const int voc_parts,
    const int max_smem_per_block,
    cudaStream_t stream) {
  constexpr int beam_group_size = K / 2;
  size_t const nShareMemory = sizeof(float) * voc_parts * (2 * K + 2) +
                              sizeof(cub::KeyValuePair<int, T>) * K;
  // Explicit comparison in size_t. A non-positive budget never selects the
  // shared-memory variant.
  const bool fits_in_smem =
      max_smem_per_block > 0 &&
      nShareMemory < static_cast<size_t>(max_smem_per_block);
  if (fits_in_smem) {  // IS_FAST_KERNEL must be a compile-time constant
    if (voc_parts <= 32) {
      BEAM_STAGE2_KERNEL(32, true)
    } else if (voc_parts <= 64) {
      BEAM_STAGE2_KERNEL(64, true)
    } else {
      // No larger branch: voc_parts is expected to stay within
      // kMaxVocabPartForStage1FastKernel (128).
      // NOTE(review): this is not enforced here.
      BEAM_STAGE2_KERNEL(128, true)
    }
  }
  BEAM_STAGE2_KERNEL(128, false)
}
// Stage 2 (generic path): one block per (batch, group beam) row merges that
// row's voc_parts packed stage-1 records (staged in dynamic shared memory) into
// the final K candidates. It writes the vocab ids to tmp_ids and the scores
// plus cum_scores[row] to tmp_vals, K slots per group beam. With fuse_softmax,
// the scores are log-softmax values; otherwise they are raw logits.
// Thread t (< voc_parts) owns record t, so THREADBLOCK_SIZE >= voc_parts.
template <typename T, int K, int THREADBLOCK_SIZE>
__global__ void beam_search_softmax_topk_stage2(BeamSearchParams<T> params,
                                                const int beam_width,
                                                const int beam_group_idx,
                                                const int voc_parts,
                                                const int packed_top_kmd_size,
                                                const bool fuse_softmax) {
  const int thread_id = threadIdx.x;
  const int beam_group_size = K / 2;
  const int batch_id = blockIdx.x / beam_group_size;
  const int beam_group_sub_idx = blockIdx.x % beam_group_size;
  // Row in the full (batch, beam_width) layout.
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_idx;
  // Row in the group-local (batch, beam_group_size) workspace layout.
  const int group_beam_batch_id = blockIdx.x;
  const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size;
  // for dybatch
  const int seq_len = params.seq_lens[beam_batch_id];
  const bool finish = params.beam_finished[beam_batch_id];
  // Output slots of this row. These are already offset by
  // group_beam_batch_id * K and must not be offset again.
  int *tmp_ids = params.tmp_ids + group_beam_batch_id * K;
  float *tmp_vals = params.tmp_vals + group_beam_batch_id * K;
  float *tmp_buffer = params.tmp_buffer;
  const T *cum_scores = params.cum_scores + beam_batch_id;
  if (seq_len < 0 || finish) {
    return;
  }
  const T MAX_T_VAL = FLT_MAX;
  extern __shared__ char buf_s_[];
  float *buf_s = reinterpret_cast<float *>(buf_s_);
  // All voc_parts records of the current (batch, group beam) row.
  tmp_buffer += group_beam_batch_id * PACKED_TOP_KMD_SIZE * voc_parts;
  if (fuse_softmax) {
    typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    TopKSoftMax<T, K> partial;
    for (int i = 0; i < K; ++i) {
      partial.topk.ids[i] = -1;
      partial.topk.vals[i] = -MAX_T_VAL;
    }
    partial.softmax_md.logit = -MAX_T_VAL;
    partial.softmax_md.score = 0.0F;
    for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
         idx += THREADBLOCK_SIZE) {
      buf_s[idx] = tmp_buffer[idx];
    }
    __syncthreads();
    if (threadIdx.x < voc_parts) {
      float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
      for (int i = 0; i < K; i++) {
        partial.topk.ids[i] = reinterpret_cast<int *>(b_s)[i];
        partial.topk.vals[i] = b_s[K + i];
      }
      partial.softmax_md.score = b_s[2 * K];
      partial.softmax_md.logit = b_s[2 * K + 1];
    }
    __syncthreads();
    TopKSoftMax<T, K> total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op<T, K>);
    if (thread_id == 0) {
      // log-softmax(v) = v - max - log(sum(exp(x - max)))
      float d_total_log = logf(total.softmax_md.score);
      for (int i = 0; i < K; ++i) {
        float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log;
        tmp_ids[i] = total.topk.ids[i];
        tmp_vals[i] = val + cum_scores[0];
      }
    }
  } else {
    typedef cub::BlockReduce<TopK<T, K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    TopK<T, K> partial;
    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -MAX_T_VAL;
    }
    for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
         idx += THREADBLOCK_SIZE) {
      buf_s[idx] = tmp_buffer[idx];
    }
    __syncthreads();
    if (threadIdx.x < voc_parts) {
      float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
      for (int i = 0; i < K; i++) {
        partial.ids[i] = reinterpret_cast<int *>(b_s)[i];
        partial.vals[i] = b_s[K + i];
      }
    }
    __syncthreads();
    TopK<T, K> total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, K>);
    if (thread_id == 0) {
      // tmp_ids / tmp_vals already point at this row's K slots; offsetting
      // them again here previously wrote to row 2 * group_beam_batch_id.
      for (int i = 0; i < K; ++i) {
        float val = total.vals[i];
        tmp_ids[i] = total.ids[i];
        tmp_vals[i] = val + cum_scores[0];
      }
    }
  }
}
// Launches the generic stage-2 kernel with one block per (batch, group beam)
// row. The block size is the smallest of 32/64/128/256 threads that gives one
// thread per vocab part. Each row's packed records are staged in dynamic
// shared memory. NOTE(review): nothing is launched when voc_parts > 256.
template <typename T, int K>
void invokeBeamSearchSoftmaxTopKStage2(BeamSearchParams<T> *params,
                                       const int batch_size,
                                       const int beam_width,
                                       const int beam_group_idx,
                                       const int voc_parts,
                                       const int packed_top_kmd_size,
                                       const bool fuse_softmax,
                                       gpuStream_t stream) {
  int smem_bytes = voc_parts * packed_top_kmd_size * sizeof(float);
  const int group_size = K / 2;
  const int num_blocks = batch_size * group_size;
  if (voc_parts <= 32) {
    beam_search_softmax_topk_stage2<T, K, 32>
        <<<num_blocks, 32, smem_bytes, stream>>>(*params,
                                                 beam_width,
                                                 beam_group_idx,
                                                 voc_parts,
                                                 packed_top_kmd_size,
                                                 fuse_softmax);
  } else if (voc_parts <= 64) {
    beam_search_softmax_topk_stage2<T, K, 64>
        <<<num_blocks, 64, smem_bytes, stream>>>(*params,
                                                 beam_width,
                                                 beam_group_idx,
                                                 voc_parts,
                                                 packed_top_kmd_size,
                                                 fuse_softmax);
  } else if (voc_parts <= 128) {
    beam_search_softmax_topk_stage2<T, K, 128>
        <<<num_blocks, 128, smem_bytes, stream>>>(*params,
                                                  beam_width,
                                                  beam_group_idx,
                                                  voc_parts,
                                                  packed_top_kmd_size,
                                                  fuse_softmax);
  } else if (voc_parts <= 256) {
    beam_search_softmax_topk_stage2<T, K, 256>
        <<<num_blocks, 256, smem_bytes, stream>>>(*params,
                                                  beam_width,
                                                  beam_group_idx,
                                                  voc_parts,
                                                  packed_top_kmd_size,
                                                  fuse_softmax);
  }
}
// One block per batch entry (K beams each). Thread 0 marks all K beams
// finished once the K-th kept hypothesis score exceeds -1e8.
template <typename T, int K>
__global__ void update_beam_finished_early_stop(const T *beam_hyps_score_out,
                                                bool *beam_finished) {
  if (threadIdx.x != 0) {
    return;
  }
  const int batch_idx = blockIdx.x;
  const T *scores = beam_hyps_score_out + batch_idx * K;
  bool *flags = beam_finished + batch_idx * K;
  if (!(scores[K - 1] > -1e8)) {
    return;
  }
  for (int beam = 0; beam < K; ++beam) {
    flags[beam] = true;
  }
}
// <<<batch_size>>>
template <typename T, int K, int THREADBLOCK_SIZE, bool GROUP>
__global__ void batch_topk(BeamSearchParams<T> params,
const int beam_width,
const int beam_group_idx,
const int dec_stride) {
const bool early_stop = params.early_stop;
const int thread_id = threadIdx.x;
const int batch_id = blockIdx.x;
// int block_id = blockIdx.x; // bs
const int beam_group_size = K / 2;
const int beam_group_start_id =
batch_id * beam_width + beam_group_idx * beam_group_size;
bool *beam_finished = params.beam_finished + beam_group_start_id;
const int *step_ids = params.step_ids + beam_group_start_id;
int *next_tokens = params.next_tokens + beam_group_start_id;
float *cum_scores_out = params.cum_scores_out + beam_group_start_id;
int *parent_ids = params.parent_ids + beam_group_start_id;
float *beam_hyps_score_out = params.beam_hyps_score_out + beam_group_start_id;
const bool finish = beam_finished[0];
const int step_id = step_ids[0];
const int seq_len = params.seq_lens[beam_group_start_id];
const int max_dec_len = params.max_dec_lens[beam_group_start_id];
const bool last_dec_step = (step_id + 1 == max_dec_len);
if (thread_id == 0 && seq_len > 0 && !finish) {
TopK<T, K> partial;
BeamHypothesesTopK<T, K / 2> beam_hyps;
beam_hyps.init(params.beam_hyps_out + beam_group_start_id * dec_stride,
params.beam_hyps_score_out + beam_group_start_id,
dec_stride);
for (int i = 0; i < K; ++i) {
partial.ids[i] = -1;
partial.vals[i] = -FLT_MAX;
partial.parent_ids[i] = -1;
}
int index = batch_id * beam_group_size * K;
if (step_id == 0) {
for (int i = 0; i < K; i++) {
float score_now = apply_length_penalty(params.tmp_vals[index + i],
step_id + 1,
params.length_penalty[batch_id]);
if (!GROUP) {
score_now -=
params.diversity_penalty[batch_id] * static_cast<float>(i + 1);
}
partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
}
} else {
for (int i = 0; i < beam_group_size * K; i++) {
float score_now = apply_length_penalty(params.tmp_vals[index + i],
step_id + 1,
params.length_penalty[batch_id]);
if (!GROUP) {
score_now -= params.diversity_penalty[batch_id] *
static_cast<float>(i % K + 1);
}
partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
}
}
if (partial.vals[0] < beam_hyps.hyps[beam_group_size - 1].score) {
for (int i = 0; i < beam_group_size; i++) {
beam_finished[i] = true;
}
return;
}
int next_step_num = 0;
for (int i = 0; i < K && next_step_num < beam_group_size; i++) {
int parent_id = partial.parent_ids[i];