-
Notifications
You must be signed in to change notification settings - Fork 740
Expand file tree
/
Copy pathblock_attn.cc
More file actions
1144 lines (1099 loc) · 51.1 KB
/
block_attn.cc
File metadata and controls
1144 lines (1099 loc) · 51.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <blocks/core_attention_block.h>
#include <core/check.h>
#include <core/context.h>
#include <core/param.h>
#include <core/types.h>
#include <flash_api.h>
#include <infer_ops.h>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft_api.h>
#include "ops/pybind/cachekv_signal_thread_worker.h"
#include "ops/remote_cache_kv_ipc.h"
#include "ops/utility/env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);
XPU_DECLARE_BOOL(use_pd_disaggregation_per_chunk, false);
namespace xftblock = baidu::xpu::xftblock;
/**
 * Selects the scale element types used as template arguments for the
 * split-rope cache kernels, keyed on the XPU cache type TC and the XPU scale
 * type TS:
 *   - E_Scale: scale type argument of the encoder-side split-rope kernel.
 *   - D_Scale: scale type argument of the decoder-side split kernels.
 * Unless specialized, both aliases are simply TS.
 */
template <typename EScale, typename DScale>
struct SplitRopeScalePair {
  using E_Scale = EScale;
  using D_Scale = DScale;
};

template <typename TC, typename TS>
struct SplitRopeTypeTrait : SplitRopeScalePair<TS, TS> {};

// bf16 cache with bf16 scales: decoder-side scale type is float.
template <>
struct SplitRopeTypeTrait<bfloat16, bfloat16>
    : SplitRopeScalePair<bfloat16, float> {};

// int8 cache with bf16 scales: both encoder and decoder use bf16.
template <>
struct SplitRopeTypeTrait<int8_t, bfloat16>
    : SplitRopeScalePair<bfloat16, bfloat16> {};
/**
 * Fused block (paged) attention over a mixed batch of encoder (prefill) and
 * decoder tokens on XPU.
 *
 * The first `total_enc_len` rows of `qkv` belong to encoder sequences and the
 * remaining `token_num - total_enc_len` rows to decoder sequences. Each half
 * runs its "rope + cache" split kernel, then attention, and writes its result
 * into the matching row range of the returned tensor.
 *
 * qkv shape: [token_num, (num_heads + 2 * kv_num_heads) * head_dim]
 * key_cache/value_cache dims: [cache_shape[0], kv_num_heads, block_size,
 *                              head_dim]
 * rotary_embs rank: 6 when rope_3d, otherwise 5.
 * k_scales/v_scales value: 127 / max (type = TS)
 * k_scales_inv/v_scales_inv value:
 *   1. perchannel with zp: max / 127 (type = TS)
 *   2. perchannel without zp: max (type = float)
 *
 * len_info_cpu (int32, host) layout as read below:
 *   [0] enc_batch, [1] dec_batch, [2] total_enc_len, [3] max_enc_len,
 *   [4] max_kv_len, [5] prefix_block_num_per_seq, [6] max_dec_len
 *
 * Decoding takes the speculative (multi-token) path when
 * total_dec_len != dec_batch, otherwise the one-token-per-sequence path.
 *
 * @return a single tensor of shape [token_num, num_heads * head_dim] with
 *         qkv's dtype and place.
 **/
template <typename TX, typename TC, typename TS>
std::vector<paddle::Tensor> BlockAttnKernel(
    const paddle::Tensor& qkv,
    const paddle::Tensor& key_cache,
    const paddle::Tensor& value_cache,
    const paddle::Tensor& rotary_embs,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& prefix_block_tables,
    const paddle::Tensor& len_info_cpu,
    const paddle::Tensor& encoder_seq_lod_cpu,
    const paddle::Tensor& decoder_seq_lod_cpu,
    const paddle::Tensor& encoder_kv_lod_cpu,
    const paddle::Tensor& encoder_batch_map_cpu,
    const paddle::Tensor& decoder_context_len_cpu,
    const paddle::Tensor& decoder_context_len_cache_cpu,
    const paddle::Tensor& decoder_batch_map_cpu,
    const paddle::Tensor& prefix_len_cpu,
    const paddle::Tensor& encoder_seq_lod,
    const paddle::Tensor& decoder_seq_lod,
    const paddle::Tensor& encoder_kv_lod,
    const paddle::Tensor& encoder_batch_map,
    const paddle::Tensor& decoder_context_len,
    const paddle::Tensor& decoder_context_len_cache,
    const paddle::Tensor& decoder_batch_map,
    const paddle::Tensor& prefix_len,
    const paddle::optional<paddle::Tensor>& k_scales,
    const paddle::optional<paddle::Tensor>& v_scales,
    const paddle::optional<paddle::Tensor>& k_scales_inv,
    const paddle::optional<paddle::Tensor>& v_scales_inv,
    const paddle::optional<paddle::Tensor>& k_zeros,
    const paddle::optional<paddle::Tensor>& v_zeros,
    const paddle::optional<paddle::Tensor>& shift,
    const paddle::optional<paddle::Tensor>& smooth,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
    const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
    const bool use_neox_rotary_style,
    const bool rope_3d) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
  xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
  auto rt_guard = xctx.get_rt_guard();
  using XPU_XType = typename XPUTypeTrait<TX>::Type;
  using XPU_CType = typename XPUTypeTrait<TC>::Type;
  using XPU_SType = typename XPUTypeTrait<TS>::Type;
  using E_Scale = typename SplitRopeTypeTrait<XPU_CType, XPU_SType>::E_Scale;
  using D_Scale = typename SplitRopeTypeTrait<XPU_CType, XPU_SType>::D_Scale;
  typedef TX data_t;
  typedef TC cdata_t;
  typedef TS sdata_t;
  // Scratch q/k/v buffers and the attention output use the activation dtype.
  xftblock::DataType KV_BUF_TYPE = std::is_same<bfloat16, XPU_XType>::value
                                       ? xftblock::DataType::DT_BFLOAT16
                                       : xftblock::DataType::DT_FLOAT16;
  auto qkv_shape = qkv.dims();
  auto cache_shape = key_cache.dims();
  auto block_table_shape = block_tables.dims();
  const int block_batch = block_table_shape[0];
  const int max_block_per_seq = block_table_shape[1];
  const int kv_num_heads = cache_shape[1];
  const int block_size = cache_shape[2];
  const int head_dim = cache_shape[3];
  const int max_seq_len = block_size * max_block_per_seq;
  const int token_num = qkv_shape[0];
  const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
  const int num_heads = total_num_head - 2 * kv_num_heads;
  const int hidden_dim = num_heads * head_dim;
  // Seven scalars are read from len_info_cpu below; fail loudly instead of
  // reading past the end of a short host buffer.
  PD_CHECK(len_info_cpu.numel() >= 7,
           "len_info_cpu should contain at least 7 elements");
  int enc_batch = len_info_cpu.data<int32_t>()[0];
  int dec_batch = len_info_cpu.data<int32_t>()[1];
  int total_enc_len = len_info_cpu.data<int32_t>()[2];
  int total_dec_len = token_num - total_enc_len;
  int max_enc_len = len_info_cpu.data<int32_t>()[3];
  int max_kv_len = len_info_cpu.data<int32_t>()[4];
  int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5];
  int max_dec_len = len_info_cpu.data<int32_t>()[6];
  int rope_max_seqlen = 0;
  int rope_head_dim = 0;
  if (rope_3d) {
    PD_CHECK(rotary_embs.dims().size() == 6,
             "rotary_embs dim size should be 6 in multi-modal model");
    rope_max_seqlen = rotary_embs.dims()[3];
    rope_head_dim = rotary_embs.dims()[5];
  } else {
    PD_CHECK(rotary_embs.dims().size() == 5,
             "rotary_embs dim size should be 5 in language model");
    rope_max_seqlen = rotary_embs.dims()[2];
    rope_head_dim = rotary_embs.dims()[4];
  }
  std::string pos_emb_type;
  if (use_neox_rotary_style) {
    pos_emb_type = "NEOX";
  } else if (rope_head_dim == head_dim / 2) {
    // vl model use this
    pos_emb_type = "HALF_HEAD_DIM";
  } else {
    pos_emb_type = "NORMAL";
  }
  auto block_attn_out =
      paddle::empty({token_num, hidden_dim}, qkv.type(), qkv.place());
  // TODO(lizanz03): only support c8 zp per channel
  bool is_cache_int8 = std::is_same<int8_t, XPU_CType>::value;
  bool has_zp = k_zeros && v_zeros;
  XPU_SType *quant_k_scale{nullptr}, *quant_v_scale{nullptr},
      *quant_k_scale_inv_zp{nullptr}, *quant_v_scale_inv_zp{nullptr},
      *quant_k_zp{nullptr}, *quant_v_zp{nullptr};
  // maxptr for xfa
  float *quant_k_scale_inv{nullptr}, *quant_v_scale_inv{nullptr};
  XPU_XType *p_shift{nullptr}, *p_smooth{nullptr};
  if (is_cache_int8) {
    // only support c8 per channel
    quant_k_scale = reinterpret_cast<XPU_SType*>(
        const_cast<sdata_t*>(k_scales.get().data<sdata_t>()));
    quant_v_scale = reinterpret_cast<XPU_SType*>(
        const_cast<sdata_t*>(v_scales.get().data<sdata_t>()));
    if (shift) {
      p_shift = reinterpret_cast<XPU_XType*>(
          const_cast<data_t*>(shift.get().data<data_t>()));
    }
    if (smooth) {
      p_smooth = reinterpret_cast<XPU_XType*>(
          const_cast<data_t*>(smooth.get().data<data_t>()));
    }
    if (has_zp) {
      quant_k_scale_inv_zp = reinterpret_cast<XPU_SType*>(
          const_cast<sdata_t*>(k_scales_inv.get().data<sdata_t>()));
      quant_v_scale_inv_zp = reinterpret_cast<XPU_SType*>(
          const_cast<sdata_t*>(v_scales_inv.get().data<sdata_t>()));
      quant_k_zp = reinterpret_cast<XPU_SType*>(
          const_cast<sdata_t*>(k_zeros.get().data<sdata_t>()));
      quant_v_zp = reinterpret_cast<XPU_SType*>(
          const_cast<sdata_t*>(v_zeros.get().data<sdata_t>()));
    } else {
      quant_k_scale_inv = reinterpret_cast<float*>(
          const_cast<float*>(k_scales_inv.get().data<float>()));
      quant_v_scale_inv = reinterpret_cast<float*>(
          const_cast<float*>(v_scales_inv.get().data<float>()));
    }
  }
  const float *q_norm_weight_data{nullptr}, *k_norm_weight_data{nullptr};
  if (q_norm_weight) {
    q_norm_weight_data = q_norm_weight.get().data<float>();
  }
  if (k_norm_weight) {
    k_norm_weight_data = k_norm_weight.get().data<float>();
  }
  // The NEOX split kernels below accept neither q nor k norm weights, so
  // reject both instead of silently dropping a k norm weight.
  PD_CHECK(!(pos_emb_type == "NEOX" &&
             (q_norm_weight_data != nullptr || k_norm_weight_data != nullptr)),
           "split_neox_cache_kv_encoder not support q/k norm weight");
  // Shared status code for every XPU call below.
  int ret = 0;
  if (enc_batch > 0) {
    xftblock::TransformerParam param;
    xftblock::TransformerVsl vsl;
    param.batch_size = enc_batch;
    param.head_num = num_heads;
    param.kv_head_num = kv_num_heads;
    param.head_dim = head_dim;
    param.max_batch_size = block_batch;
    param.max_seq_len = max_seq_len;
    param.use_cache_per_channel =
        is_cache_int8 && !has_zp;  // only support c8 per channel
    vsl.usual_lod_vp = {
        const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
        enc_batch + 1,
        const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
    vsl.kv_lod_vp = {const_cast<int32_t*>(encoder_kv_lod_cpu.data<int32_t>()),
                     enc_batch + 1,
                     const_cast<int32_t*>(encoder_kv_lod.data<int32_t>())};
    vsl.slot_mapping_vp = {
        const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
        enc_batch,
        const_cast<int32_t*>(encoder_batch_map.data<int32_t>())};  // real batch
    param.max_valid_seqlen = max_enc_len;
    param.max_kv_valid_seqlen = max_kv_len;
    // setting for prefix cache
    bool is_prefix_cache = prefix_block_num_per_seq > 0;
    param.prefill_len = is_prefix_cache ? param.max_valid_seqlen : -1;
    param.page_attn.block_size = block_size;
    param.page_attn.max_num_blocks_per_seq = prefix_block_num_per_seq;
    // prefix_block_tables is a subset of block_tables, which is used for
    // prefix cache
    xftblock::Tensor prefix_block_tables_tensor(
        is_prefix_cache ? reinterpret_cast<void*>(const_cast<int32_t*>(
                              prefix_block_tables.data<int32_t>()))
                        : nullptr,
        xftblock::DataType::DT_INT32,
        {prefix_block_tables.dims()[0], prefix_block_num_per_seq});
    param.page_attn.block_table = &prefix_block_tables_tensor;
    baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
        const_cast<int32_t*>(prefix_len_cpu.data<int32_t>()),
        enc_batch,
        const_cast<int32_t*>(prefix_len.data<int32_t>())};
    float* fake_perhead_scale = nullptr;
    if (is_cache_int8 && has_zp && is_prefix_cache) {
      fake_perhead_scale = RAII_GUARD.alloc<float>(param.kv_head_num);
      // Fill fake_perhead_scale with 127.f. NOTE(review): presumably a
      // per-head max of 127 makes the per-head dequant a no-op
      // (max / 127 == 1) so only the per-channel zp scales apply — confirm.
      ret = api::constant<float>(
          xpu_ctx->x_context(), fake_perhead_scale, param.kv_head_num, 127.f);
      PD_CHECK(ret == api::SUCCESS, "api::constant failed.");
    }
    // buf tensor
    bool use_l3 = false;
    xftblock::Tensor q_buf(
        rt_guard, KV_BUF_TYPE, {total_enc_len, hidden_dim}, false, false);
    xftblock::Tensor k_buf(
        rt_guard->alloc(                              // src_data
            KV_BUF_TYPE,                              // data_type
            total_enc_len * kv_num_heads * head_dim,  // numel
            xftblock::AllocType::GM,                  // alloc_type
            &use_l3),                                 // use_l3
        nullptr,                                      // max_data
        nullptr,                                      // pc_scale
        KV_BUF_TYPE,                                  // type
        {total_enc_len, kv_num_heads * head_dim});
    xftblock::Tensor v_buf(
        rt_guard->alloc(                              // src_data
            KV_BUF_TYPE,                              // data_type
            total_enc_len * kv_num_heads * head_dim,  // numel
            xftblock::AllocType::GM,                  // alloc_type
            &use_l3),                                 // use_l3
        nullptr,                                      // max_data
        nullptr,                                      // pc_scale
        KV_BUF_TYPE,                                  // type
        {total_enc_len, kv_num_heads * head_dim});
    // kv cache tensor
    xftblock::Tensor key_cache_tensor(
        reinterpret_cast<void*>(
            const_cast<cdata_t*>(key_cache.data<cdata_t>())),  // src_data
        nullptr,                                               // max_data
        has_zp                                                 // pc_scale
            ? fake_perhead_scale
            : quant_k_scale_inv,
        is_cache_int8  // cache type
            ? xftblock::DataType::DT_INT8
            : KV_BUF_TYPE,
        {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]});
    xftblock::Tensor value_cache_tensor(
        reinterpret_cast<void*>(
            const_cast<cdata_t*>(value_cache.data<cdata_t>())),  // src_data
        nullptr,                                                 // max_data
        has_zp                                                   // pc_scale
            ? fake_perhead_scale
            : quant_v_scale_inv,
        is_cache_int8  // cache type
            ? xftblock::DataType::DT_INT8
            : KV_BUF_TYPE,
        {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]});
    // Encoder rows occupy the first total_enc_len rows of the output.
    xftblock::Tensor encode_output(reinterpret_cast<void*>(const_cast<data_t*>(
                                       block_attn_out.data<data_t>())),
                                   KV_BUF_TYPE,
                                   {total_enc_len, hidden_dim});
    // rope + cache
    if (pos_emb_type == "NEOX") {
      ret = infer_ops::
          split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
              xpu_ctx->x_context(),
              reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()),  // qkv
              reinterpret_cast<const float*>(
                  rotary_embs.data<float>()),  // rotary_pos_emb
              reinterpret_cast<const int*>(
                  block_tables.data<int>()),  // block_table
              q_buf.data<XPU_XType>(),
              k_buf.data<XPU_XType>(),
              v_buf.data<XPU_XType>(),
              const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                  key_cache.data<cdata_t>())),
              const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                  value_cache.data<cdata_t>())),
              vsl.usual_lod_vp,     // seq_lod
              vsl.slot_mapping_vp,  // real_batch
              prefix_lens_vp,       // start_tokens
              param.batch_size,     // batch_size
              1,                    // emb_batch_size
              rope_max_seqlen,      // max_seqlen
              param.head_num,
              param.kv_head_num,
              param.head_dim,
              rope_head_dim,
              param.max_batch_size,
              block_size,
              max_block_per_seq,
              "BLHD",
              "HLD",
              pos_emb_type,
              nullptr,  // k_cache_scale_inv - use for per head
              nullptr,  // v_cache_scale_inv - use for per head
              nullptr,  // intx_k_pc_scale
              nullptr,  // intx_v_pc_scale
              nullptr,  // intx_k_pc_zero
              nullptr,  // intx_v_pc_zero
              rope_3d);
      PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
    } else {
      ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
                                                   float,
                                                   XPU_CType,
                                                   int,
                                                   E_Scale>(
          xpu_ctx->x_context(),
          reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()),  // qkv
          reinterpret_cast<const float*>(
              rotary_embs.data<float>()),  // rotary_pos_emb
          reinterpret_cast<const int*>(
              block_tables.data<int>()),  // block_table
          q_buf.data<XPU_XType>(),
          k_buf.data<XPU_XType>(),
          v_buf.data<XPU_XType>(),
          const_cast<XPU_CType*>(
              reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
          const_cast<XPU_CType*>(
              reinterpret_cast<const XPU_CType*>(value_cache.data<cdata_t>())),
          vsl.usual_lod_vp,     // seq_lod
          vsl.slot_mapping_vp,  // real_batch
          prefix_lens_vp,       // start_tokens
          param.batch_size,     // batch_size
          1,                    // emb_batch_size
          rope_max_seqlen,      // max_seqlen
          param.head_num,
          param.kv_head_num,
          param.head_dim,
          param.max_batch_size,
          block_size,
          max_block_per_seq,
          "BLHD",
          "HLD",
          pos_emb_type,
          nullptr,        // k_cache_scale_inv - use for per head
          nullptr,        // v_cache_scale_inv - use for per head
          quant_k_scale,  // intx_k_pc_scale
          quant_v_scale,  // intx_v_pc_scale
          quant_k_zp,     // intx_k_pc_zero
          quant_v_zp,     // intx_v_pc_zero
          q_norm_weight_data,
          k_norm_weight_data,
          rope_3d);
      PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
    }
    // pd split: after the encoder's rope + cache kernels are enqueued, record
    // an event on the main stream and hand it to the signal worker.
    if (FLAGS_fmt_write_cache_completed_signal) {
      XPUEvent write_event = nullptr;
      ret = xpu_event_create(&write_event);
      PD_CHECK(ret == 0, "xpu_event_create write_event failed.");
      ret = xpu_event_record(write_event, xctx.get_main_stream());
      PD_CHECK(ret == 0, "xpu_event_record failed.");
      PD_CHECK(cachekv_signal_thread_cpu,
               "cachekv_signal_thread should not be nullptr");
      // The worker pointer is passed in as an int64 value in a host tensor.
      auto worker = reinterpret_cast<CacheKvSignalThreadWorker*>(
          cachekv_signal_thread_cpu.get().data<int64_t>()[0]);
      PD_CHECK(worker != nullptr,
               "cachekv_signal_thread should not be nullptr");
      if (FLAGS_use_pd_disaggregation_per_chunk) {
        worker->push_signal_task_per_query(write_event, nullptr);
      } else {
        // If use micro batch:
        // micro_batch_0 do nothing.
        // micro_batch_1 write kv signal.
        if (kv_signal_data_cpu) {
          worker->push_signal_task(
              write_event,
              reinterpret_cast<void*>((const_cast<int64_t*>(
                  kv_signal_data_cpu.get().data<int64_t>()))));
        }
      }
    }
    if (is_cache_int8 && has_zp && is_prefix_cache) {
      int64_t q_head_num = param.head_num;
      int64_t kv_head_num = param.kv_head_num;
      // assume q_layout is BLHD, q = q * k_scales_inv
      ret = api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          q_buf.data<XPU_XType>(),
                                          quant_k_scale_inv_zp,
                                          q_buf.data<XPU_XType>(),
                                          {total_enc_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
    }
    // attn encode: with a prefix cache, attend over the paged cache;
    // otherwise attend directly over the freshly split k/v buffers.
    if (is_prefix_cache) {
      ret =
          xftblock::xft_context_core_attenion_block<XPU_XType,
                                                    XPU_CType,
                                                    float>(&xctx,
                                                           &q_buf,
                                                           &key_cache_tensor,
                                                           &value_cache_tensor,
                                                           &encode_output,
                                                           param,
                                                           vsl);
    } else {
      ret = xftblock::
          xft_context_core_attenion_block<XPU_XType, XPU_XType, float>(
              &xctx, &q_buf, &k_buf, &v_buf, &encode_output, param, vsl);
    }
    PD_CHECK(ret == api::SUCCESS,
             "xftblock::xft_context_core_attenion_block failed.");
    if (is_cache_int8 && has_zp && is_prefix_cache) {
      int64_t q_head_num = param.head_num;
      int64_t kv_head_num = param.kv_head_num;
      // out = (out - v_zeros) * v_scales_inv
      ret = api::broadcast_sub<XPU_XType>(xpu_ctx->x_context(),
                                          encode_output.data<XPU_XType>(),
                                          quant_v_zp,
                                          encode_output.data<XPU_XType>(),
                                          {total_enc_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_sub failed.");
      ret = api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          encode_output.data<XPU_XType>(),
                                          quant_v_scale_inv_zp,
                                          encode_output.data<XPU_XType>(),
                                          {total_enc_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
    }
    if (p_shift != nullptr) {
      ret = api::broadcast_add<XPU_XType>(xpu_ctx->x_context(),
                                          p_shift,
                                          encode_output.data<XPU_XType>(),
                                          encode_output.data<XPU_XType>(),
                                          {1, hidden_dim},
                                          {total_enc_len, hidden_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_add for shift failed.");
    }
    if (p_smooth != nullptr) {
      ret = api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          p_smooth,
                                          encode_output.data<XPU_XType>(),
                                          encode_output.data<XPU_XType>(),
                                          {1, hidden_dim},
                                          {total_enc_len, hidden_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul for smooth failed.");
    }
  }
  if (dec_batch > 0) {
    xftblock::TransformerParam param;
    xftblock::TransformerVsl vsl;
    param.batch_size = dec_batch;
    param.head_num = num_heads;
    param.kv_head_num = kv_num_heads;
    param.head_dim = head_dim;
    param.max_batch_size = block_batch;
    param.max_seq_len = max_seq_len;
    param.use_page_attn = true;
    // Decoder rows follow the encoder rows in the output.
    xftblock::Tensor decode_output(
        reinterpret_cast<void*>(
            const_cast<data_t*>(block_attn_out.data<data_t>()) +
            total_enc_len * hidden_dim),
        KV_BUF_TYPE,
        {total_dec_len, hidden_dim});
    if (total_dec_len != dec_batch) {
      // Speculative / multi-token decode (e.g. MTP): more than one query
      // token for at least one sequence.
      xftblock::Tensor q_buf(
          rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false);
      xftblock::Tensor k_buf(rt_guard,
                             KV_BUF_TYPE,
                             {total_dec_len, kv_num_heads * head_dim},
                             false,
                             false);
      xftblock::Tensor v_buf(rt_guard,
                             KV_BUF_TYPE,
                             {total_dec_len, kv_num_heads * head_dim},
                             false,
                             false);
      api::VectorParam<int32_t> decoder_context_len_vp = {
          const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
          dec_batch,
          const_cast<int32_t*>(
              decoder_context_len
                  .data<int32_t>())};  // use for speculative_attention_decoder
                                       // seq_len in MTP
      api::VectorParam<int32_t> decoder_context_len_cache_vp = {
          const_cast<int32_t*>(decoder_context_len_cache_cpu.data<int32_t>()),
          dec_batch,
          const_cast<int32_t*>(
              decoder_context_len_cache
                  .data<int32_t>())};  // use for split rope enc as prefix cache
                                       // len in MTP
      api::VectorParam<int32_t> decoder_batch_map_vp = {
          const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
          dec_batch,
          const_cast<int32_t*>(
              decoder_batch_map.data<int32_t>())};  // real batch
      api::VectorParam<int32_t> decoder_seq_lod_vp = {
          const_cast<int32_t*>(decoder_seq_lod_cpu.data<int32_t>()),
          dec_batch + 1,
          const_cast<int32_t*>(
              decoder_seq_lod
                  .data<int32_t>())};  // use for split rope enc as lod in MTP
      // rope + cache
      if (pos_emb_type == "NEOX") {
        ret = infer_ops::
            split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
                xpu_ctx->x_context(),
                reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
                    total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
                reinterpret_cast<const float*>(
                    rotary_embs.data<float>()),  // rotary_pos_emb
                reinterpret_cast<const int*>(
                    block_tables.data<int>()),  // block_table
                q_buf.data<XPU_XType>(),
                k_buf.data<XPU_XType>(),
                v_buf.data<XPU_XType>(),
                const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                    key_cache.data<cdata_t>())),
                const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                    value_cache.data<cdata_t>())),
                decoder_seq_lod_vp,            // seq_lod
                decoder_batch_map_vp,          // real_batch
                decoder_context_len_cache_vp,  // start_tokens
                param.batch_size,              // batch_size
                1,                             // emb_batch_size
                rope_max_seqlen,               // max_seqlen
                param.head_num,
                param.kv_head_num,
                param.head_dim,
                rope_head_dim,
                param.max_batch_size,
                block_size,
                max_block_per_seq,
                "BLHD",
                "HLD",
                pos_emb_type,
                nullptr,  // k_cache_scale_inv - use for per head
                nullptr,  // v_cache_scale_inv - use for per head
                nullptr,  // intx_k_pc_scale
                nullptr,  // intx_v_pc_scale
                nullptr,  // intx_k_pc_zero
                nullptr,  // intx_v_pc_zero
                rope_3d);
        PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
      } else {
        ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
                                                     float,
                                                     XPU_CType,
                                                     int,
                                                     E_Scale>(
            xpu_ctx->x_context(),
            reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
                total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
            reinterpret_cast<const float*>(
                rotary_embs.data<float>()),  // rotary_pos_emb
            reinterpret_cast<const int*>(
                block_tables.data<int>()),  // block_table
            q_buf.data<XPU_XType>(),
            k_buf.data<XPU_XType>(),
            v_buf.data<XPU_XType>(),
            const_cast<XPU_CType*>(
                reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
            const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                value_cache.data<cdata_t>())),
            decoder_seq_lod_vp,            // seq_lod
            decoder_batch_map_vp,          // real_batch
            decoder_context_len_cache_vp,  // start_tokens (prefix len)
            param.batch_size,              // batch_size
            1,                             // emb_batch_size
            rope_max_seqlen,               // max_seqlen
            param.head_num,
            param.kv_head_num,
            param.head_dim,
            param.max_batch_size,
            block_size,
            max_block_per_seq,
            "BLHD",
            "HLD",
            pos_emb_type,
            nullptr,        // k_cache_scale_inv - use for per head
            nullptr,        // v_cache_scale_inv - use for per head
            quant_k_scale,  // intx_k_pc_scale
            quant_v_scale,  // intx_v_pc_scale
            quant_k_zp,     // intx_k_pc_zero
            quant_v_zp,     // intx_v_pc_zero
            q_norm_weight_data,
            k_norm_weight_data,
            rope_3d);
        PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed.");
      }
      float* fake_perhead_scale = nullptr;
      if (is_cache_int8 && has_zp) {
        int64_t q_head_num = param.head_num;
        int64_t kv_head_num = param.kv_head_num;
        fake_perhead_scale = RAII_GUARD.alloc<float>(kv_head_num);
        // q = q * k_scales_inv
        ret =
            api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          q_buf.data<XPU_XType>(),
                                          quant_k_scale_inv_zp,
                                          q_buf.data<XPU_XType>(),
                                          {total_dec_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
        PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
        // Fill fake_perhead_scale with 127.f. NOTE(review): presumably a
        // per-head max of 127 makes the per-head dequant a no-op — confirm.
        ret = api::constant<float>(
            xpu_ctx->x_context(), fake_perhead_scale, kv_head_num, 127.f);
        PD_CHECK(ret == api::SUCCESS, "api::constant failed.");
      }
      XPU_XType* q_buf_ptr = q_buf.data<XPU_XType>();
      XPU_XType* decode_output_ptr = decode_output.data<XPU_XType>();
      // GEMM compute type: float for non-quantized caches, int8_wo_t when
      // the cache type differs from the activation type.
      using TGEMM = std::
          conditional_t<std::is_same_v<XPU_XType, XPU_CType>, float, int8_wo_t>;
      constexpr int quant_mode = std::is_same_v<XPU_CType, int8_t> ? 3 : 0;
      ret = baidu::xpu::xfa::speculative_attention_decoder<XPU_XType,
                                                           XPU_CType,
                                                           XPU_XType,
                                                           TGEMM,
                                                           TGEMM,
                                                           float,
                                                           int32_t,
                                                           quant_mode>(
          xpu_ctx->x_context(),
          decode_output_ptr,  // out
          q_buf_ptr,          // q
          nullptr,            // k
          nullptr,            // v
          reinterpret_cast<const XPU_CType*>(
              key_cache.data<cdata_t>()),  // k_cache
          reinterpret_cast<const XPU_CType*>(
              value_cache.data<cdata_t>()),  // v_cache
          reinterpret_cast<const int32_t*>(
              block_tables.data<int32_t>()),  // block_tables
          decoder_context_len_vp,             // seq_lengths
          decoder_batch_map_vp,               // valid_batch
          param.max_batch_size,               // batch_num
          max_dec_len,                        // qlen
          max_seq_len,                        // max_seq_len
          param.head_num,                     // head_num
          param.head_dim,                     // head_dim
          param.kv_head_num,                  // kv_head_num
          nullptr,                            // attn_mask
          1.0f / std::sqrt(static_cast<float>(
                     param.head_dim)),  // scale (TODO: double-check)
          block_size,                   // block_size
          max_block_per_seq,            // max_blocks_per_seq
          -1,                           // max_window_size
          nullptr,                      // q_maxptr
          has_zp                        // k_cache_maxptr
              ? fake_perhead_scale
              : quant_k_scale_inv,
          has_zp  // v_cache_maxptr
              ? fake_perhead_scale
              : quant_v_scale_inv,
          nullptr,             // o_maxptr
          param.head_dim,      // vo_head_dim
          decoder_seq_lod_vp);  // qlod
      PD_CHECK(ret == api::SUCCESS,
               "xfa::speculative_attention_decoder failed.");
    } else {
      // One query token per decoding sequence.
      vsl.usual_lod_vp = {
          const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
          dec_batch,
          const_cast<int32_t*>(decoder_context_len.data<int32_t>())};
      vsl.slot_mapping_vp = {
          const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
          dec_batch,
          const_cast<int32_t*>(
              decoder_batch_map.data<int32_t>())};  // real batch
      xftblock::Tensor q_buf(
          rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false);
      xftblock::Tensor block_table_tensor(
          reinterpret_cast<void*>(
              const_cast<int32_t*>(block_tables.data<int32_t>())),
          xftblock::DataType::DT_INT32,
          {block_table_shape[0], block_table_shape[1]});
      // normal setting
      param.use_cache_per_channel =
          is_cache_int8 && !has_zp;  // only support c8 per channel
      param.prefill_len = -1;
      param.page_attn.block_size = block_size;
      param.page_attn.max_context_len = max_seq_len;
      param.page_attn.max_num_blocks_per_seq = max_block_per_seq;
      param.page_attn.block_table = &block_table_tensor;
      // rope + cache
      if (pos_emb_type == "NEOX") {
        ret = infer_ops::split_neox_cache_kv_decoder<XPU_XType,
                                                     float,
                                                     XPU_CType,
                                                     D_Scale,
                                                     int>(
            xpu_ctx->x_context(),
            reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
                total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
            reinterpret_cast<const float*>(
                rotary_embs.data<float>()),  // rotary_pos_emb
            reinterpret_cast<const int*>(
                block_tables.data<int>()),  // block_table
            q_buf.data<XPU_XType>(),
            nullptr,
            nullptr,
            const_cast<XPU_CType*>(
                reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
            const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                value_cache.data<cdata_t>())),
            vsl.usual_lod_vp,     // seq_lod
            vsl.slot_mapping_vp,  // real_batch
            param.batch_size,     // batch_size
            1,                // emb_batch_size = rotary_embs.dims()[1] = 1
            rope_max_seqlen,  // max_seqlen
            param.head_num,
            param.kv_head_num,
            param.head_dim,
            rope_head_dim,
            param.max_batch_size,
            block_size,
            max_block_per_seq,
            "BLHD",
            "HLD",
            pos_emb_type,
            reinterpret_cast<D_Scale*>(quant_k_scale),  // k_cache_scale_inv
            reinterpret_cast<D_Scale*>(quant_v_scale),  // v_cache_scale_inv
            reinterpret_cast<D_Scale*>(quant_k_zp),     // k_cache_zp
            reinterpret_cast<D_Scale*>(quant_v_zp),     // v_cache_zp
            rope_3d);
        PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_decoder failed.");
      } else {
        ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
                                                     float,
                                                     XPU_CType,
                                                     D_Scale,
                                                     int>(
            xpu_ctx->x_context(),
            reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
                total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
            reinterpret_cast<const float*>(
                rotary_embs.data<float>()),  // rotary_pos_emb
            reinterpret_cast<const int*>(
                block_tables.data<int>()),  // block_table
            q_buf.data<XPU_XType>(),
            nullptr,
            nullptr,
            const_cast<XPU_CType*>(
                reinterpret_cast<const XPU_CType*>(key_cache.data<cdata_t>())),
            const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
                value_cache.data<cdata_t>())),
            vsl.usual_lod_vp,     // seq_lod
            vsl.slot_mapping_vp,  // real_batch
            param.batch_size,     // batch_size
            1,                // emb_batch_size = rotary_embs.dims()[1] = 1
            rope_max_seqlen,  // max_seqlen
            param.head_num,
            param.kv_head_num,
            param.head_dim,
            param.max_batch_size,
            block_size,
            max_block_per_seq,
            "BLHD",
            "HLD",
            pos_emb_type,
            reinterpret_cast<D_Scale*>(quant_k_scale),  // k_cache_scale_inv
            reinterpret_cast<D_Scale*>(quant_v_scale),  // v_cache_scale_inv
            reinterpret_cast<D_Scale*>(quant_k_zp),     // k_cache_zp
            reinterpret_cast<D_Scale*>(quant_v_zp),     // v_cache_zp
            q_norm_weight_data,
            k_norm_weight_data,
            is_cache_int8,  // bool b_c8_pc
            rope_3d);
        PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
      }
      float* fake_perhead_scale = nullptr;
      if (is_cache_int8 && has_zp) {
        int64_t q_head_num = param.head_num;
        int64_t kv_head_num = param.kv_head_num;
        fake_perhead_scale = RAII_GUARD.alloc<float>(kv_head_num);
        // q = q * k_scales_inv
        ret =
            api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          q_buf.data<XPU_XType>(),
                                          quant_k_scale_inv_zp,
                                          q_buf.data<XPU_XType>(),
                                          {total_dec_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
        PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
        // Fill fake_perhead_scale with 127.f. NOTE(review): presumably a
        // per-head max of 127 makes the per-head dequant a no-op — confirm.
        ret = api::constant<float>(
            xpu_ctx->x_context(), fake_perhead_scale, kv_head_num, 127.f);
        PD_CHECK(ret == api::SUCCESS, "api::constant failed.");
      }
      // kv cache tensor
      xftblock::Tensor key_cache_tensor(
          reinterpret_cast<void*>(
              const_cast<cdata_t*>(key_cache.data<cdata_t>())),  // src_data
          nullptr,                                               // max_data
          has_zp                                                 // pc_scale
              ? fake_perhead_scale
              : quant_k_scale_inv,
          is_cache_int8  // cache type
              ? xftblock::DataType::DT_INT8
              : KV_BUF_TYPE,
          {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]});
      xftblock::Tensor value_cache_tensor(
          reinterpret_cast<void*>(
              const_cast<cdata_t*>(value_cache.data<cdata_t>())),  // src_data
          nullptr,                                                 // max_data
          has_zp                                                   // pc_scale
              ? fake_perhead_scale
              : quant_v_scale_inv,
          is_cache_int8  // cache type
              ? xftblock::DataType::DT_INT8
              : KV_BUF_TYPE,
          {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]});
      // attn decode
      ret = xftblock::xft_decoder_core_attenion_block<
          XPU_XType,
          XPU_CType,
          XPU_XType>(  // TGEMM = XPU_XType. TODO(lizan03): use higher
                       // precision
          &xctx,
          &q_buf,
          &key_cache_tensor,
          &value_cache_tensor,
          &decode_output,
          param,
          vsl);
      PD_CHECK(ret == api::SUCCESS,
               "xftblock::xft_decoder_core_attenion_block failed.");
    }
    if (is_cache_int8 && has_zp) {
      int64_t q_head_num = param.head_num;
      int64_t kv_head_num = param.kv_head_num;
      // out = (out - v_zeros) * v_scales_inv
      if (quant_v_zp) {
        ret =
            api::broadcast_sub<XPU_XType>(xpu_ctx->x_context(),
                                          decode_output.data<XPU_XType>(),
                                          quant_v_zp,
                                          decode_output.data<XPU_XType>(),
                                          {total_dec_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
        PD_CHECK(ret == api::SUCCESS, "api::broadcast_sub failed.");
      }
      ret = api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          decode_output.data<XPU_XType>(),
                                          quant_v_scale_inv_zp,
                                          decode_output.data<XPU_XType>(),
                                          {total_dec_len,
                                           kv_head_num,
                                           q_head_num / kv_head_num,
                                           param.head_dim},
                                          {1, kv_head_num, 1, param.head_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
    }
    if (p_shift != nullptr) {
      ret = api::broadcast_add<XPU_XType>(xpu_ctx->x_context(),
                                          p_shift,
                                          decode_output.data<XPU_XType>(),
                                          decode_output.data<XPU_XType>(),
                                          {1, hidden_dim},
                                          {total_dec_len, hidden_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_add for shift failed.");
    }
    if (p_smooth != nullptr) {
      ret = api::broadcast_mul<XPU_XType>(xpu_ctx->x_context(),
                                          p_smooth,
                                          decode_output.data<XPU_XType>(),
                                          decode_output.data<XPU_XType>(),
                                          {1, hidden_dim},
                                          {total_dec_len, hidden_dim});
      PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul for smooth failed.");
    }
  }
  return {block_attn_out};
}
std::vector<paddle::Tensor> BlockAttnFused(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables,
const paddle::Tensor& len_info_cpu,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_kv_lod_cpu,
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_context_len_cpu,
const paddle::Tensor& decoder_context_len_cache_cpu,