forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattention.py
More file actions
2668 lines (2391 loc) · 115 KB
/
attention.py
File metadata and controls
2668 lines (2391 loc) · 115 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import functools
import math
import os
import weakref
from typing import List, Optional, Union, cast
import torch
from torch import nn
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import (get_sm_version, is_sm_100f, nvtx_range,
nvtx_range_debug)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..attention_backend import (AttentionInputType, AttentionMetadata,
FlashInferAttentionMetadata, TrtllmAttention,
TrtllmAttentionMetadata)
from ..attention_backend.interface import (AttentionBackend, AttentionMask,
CustomAttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask)
from ..attention_backend.sparse.dsa import (
DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import (AllReduceParams, HelixAllToAllNative, alltoall_helix,
cp_allgather, reducescatter)
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
is_torch_compiling, maybe_compiled_cat,
maybe_compiled_copy_)
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding
# Import FlashMLA sparse attention kernel
try:
from tensorrt_llm.flash_mla import flash_mla_sparse_fwd
except ImportError:
flash_mla_sparse_fwd = None
def extract_extra_attrs(layer_idx: str, attn_type: str):
    """Resolve the current attention metadata and a registered attention layer.

    Both objects are stored as weak references in the model's extra attrs
    (``get_model_extra_attrs()``): the metadata under ``"attention_metadata"``
    and the layer under ``"<attn_type>_layers"[layer_idx]``.

    Args:
        layer_idx: Registry key of the layer (e.g. ``Attention.layer_idx_str``).
        attn_type: ``"mla"`` or ``"attn"``; selects the layer registry and the
            expected metadata / layer classes.

    Returns:
        A ``(metadata, attn_layer)`` tuple.

    Raises:
        AssertionError: if anything is missing, if a weak reference has
            already been garbage-collected, or if a type check fails.
    """
    assert attn_type in ["mla", "attn"], "Invalid attention type"
    extra_attrs = get_model_extra_attrs()
    assert extra_attrs is not None, "Model extra attrs is not set"

    metadata_ref = extra_attrs.get("attention_metadata", None)
    assert metadata_ref is not None, "Attention metadata is not set"
    metadata = metadata_ref()
    # A dead weak reference returns None; report that explicitly rather than
    # failing the isinstance check below without a useful message.
    assert metadata is not None, (
        "Attention metadata has been garbage collected")
    if attn_type == "mla":
        assert isinstance(metadata, TrtllmAttentionMetadata), (
            "MLA requires TrtllmAttentionMetadata, got "
            f"{type(metadata).__name__}")
    else:
        assert isinstance(
            metadata, (FlashInferAttentionMetadata, TrtllmAttentionMetadata)
        ), ("Attention requires FlashInferAttentionMetadata or "
            f"TrtllmAttentionMetadata, got {type(metadata).__name__}")

    attn_layers = extra_attrs.get(attn_type + "_layers", None)
    assert attn_layers is not None, "Attention layer is not registered"
    attn_layer_ref = attn_layers.get(layer_idx, None)
    assert attn_layer_ref is not None, f"Cannot find attention layer for layer {layer_idx}"
    attn_layer = attn_layer_ref()
    # Same dead-weakref situation as for the metadata above.
    assert attn_layer is not None, (
        f"Attention layer for layer {layer_idx} has been garbage collected")
    if attn_type == "mla":
        assert isinstance(
            attn_layer,
            MLA), "MLA layer must be a subclass of MLA or an instance of MLA"
    elif attn_type == "attn":
        assert isinstance(
            attn_layer, Attention
        ), "Attention layer must be a subclass of Attention or an instance of Attention"
    return metadata, attn_layer
def create_attn_outputs_impl(q: torch.Tensor, attention_mask: str,
                             layer_idx: str) -> List[torch.Tensor]:
    """Create the output tensor(s) for the registered ``attn`` layer.

    Looks up the layer stored under *layer_idx* together with the current
    attention metadata, then delegates to that layer's ``create_output``.
    """
    meta, layer = extract_extra_attrs(layer_idx, "attn")
    return layer.create_output(q, meta, attention_mask)
@torch.library.custom_op("trtllm::create_attn_outputs", mutates_args=())
def create_attn_outputs(q: torch.Tensor, attention_mask: str,
                        layer_idx: str) -> List[torch.Tensor]:
    """Custom-op entry point (``trtllm::create_attn_outputs``).

    Declares no mutated arguments and simply forwards to
    ``create_attn_outputs_impl``.
    """
    return create_attn_outputs_impl(q=q,
                                    attention_mask=attention_mask,
                                    layer_idx=layer_idx)
@create_attn_outputs.register_fake
def _(q, attention_mask, layer_idx):
    # Fake implementation for tracing: it reuses the real output-creation
    # path of the registered layer. NOTE(review): this presumably only
    # allocates tensors from q's shape/dtype — confirm it stays trace-safe.
    meta, layer = extract_extra_attrs(layer_idx, "attn")
    return layer.create_output(q, meta, attention_mask)
@torch.library.custom_op("trtllm::attn_custom_op_inplace",
                         mutates_args=("output", "output_sf"))
def attn_custom_op_inplace(
    q: torch.Tensor,
    k: Optional[torch.Tensor],
    v: Optional[torch.Tensor],
    attention_mask: str,
    mrope_rotary_cos_sin: Optional[torch.Tensor],
    mrope_position_deltas: Optional[torch.Tensor],
    attention_window_size: Optional[int],
    attention_mask_data: Optional[torch.Tensor],
    attention_sinks: Optional[torch.Tensor],
    layer_idx: str,
    output: torch.Tensor,
    output_sf: Optional[torch.Tensor],
) -> None:
    """Run the registered layer's ``_attn_impl`` writing into *output*.

    Registered as ``trtllm::attn_custom_op_inplace`` with ``output`` and
    ``output_sf`` declared as mutated. The mask arrives as a string (the op
    schema has no enum type) and is rebuilt here before dispatch.
    """
    metadata, attn_layer = extract_extra_attrs(layer_idx, "attn")

    if attention_mask != CustomAttentionMask.CUSTOM:
        mask = PredefinedAttentionMask(attention_mask)
    else:
        mask = CustomAttentionMask(attention_mask)

    # NVFP4 output cannot be supported by torch compile for TRTLLM backend.
    attn_layer._attn_impl(
        q,
        k,
        v,
        metadata,
        mask,
        mrope_rotary_cos_sin,
        mrope_position_deltas,
        attention_window_size,
        attention_mask_data,
        output=output,
        output_sf=output_sf,
        attention_sinks=attention_sinks,
    )
def _helix_post_process(
    partial_o: torch.Tensor,
    softmax_stats: torch.Tensor,
    mapping: Mapping,
    num_heads_tp_cp: int,
    value_dim: int,
    aux_stream: Optional[torch.cuda.Stream] = None,
    ln_events: Optional[list] = None,
) -> torch.Tensor:
    """Exchange partial attention outputs across CP ranks and combine them.

    Shared by the MHA (``Attention``) and MLA modules, which differ only in
    *value_dim* (``head_dim`` for MHA, ``kv_lora_rank`` for MLA).

    The transport is chosen from ``mapping.cp_config``:

    * ``use_nccl_for_alltoall`` (default True): ``alltoall_helix`` over
      ``mapping.cp_group``, then ``trtllm.helix_post_process``.
    * otherwise the FIFO-based ``HelixAllToAllNative`` (MNNVL workspace);
      ``fifo_version`` (default 2) selects the tensor layout handed to
      ``trtllm.helix_post_process_native``.

    In the FIFO v1 path, when both *aux_stream* and *ln_events* are given,
    the two ``.contiguous()`` reshapes are overlapped through
    ``maybe_execute_in_parallel``.
    """
    cp_size = mapping.cp_size

    if mapping.cp_config.get("use_nccl_for_alltoall", True):
        # NCCL path: swap dims 0 and 1 so the second axis leads, then cut it
        # into cp_size equal pieces, one per destination rank.
        send_chunks = []
        for src in (partial_o, softmax_stats):
            leading = src.transpose(1, 0).contiguous()
            send_chunks.extend(
                torch.split(leading, leading.shape[0] // cp_size))
        received = alltoall_helix(send_chunks, mapping.cp_group)
        received = [chunk.transpose(1, 2).contiguous() for chunk in received]
        return torch.ops.trtllm.helix_post_process(received[0], received[1],
                                                   1.0)

    # FIFO path backed by the MNNVL workspace.
    helix = HelixAllToAllNative.get(mapping)
    num_tokens = partial_o.shape[0]

    if mapping.cp_config.get("fifo_version", 2) != 1:
        # v2: exchange 3-D views and restore the 4-D layout on receipt.
        o_send = partial_o.view(num_tokens, cp_size,
                                num_heads_tp_cp * value_dim)
        s_send = softmax_stats.view(num_tokens, cp_size, num_heads_tp_cp * 2)
        o_recv, s_recv = helix.alltoall_native(o_send, s_send)
        return torch.ops.trtllm.helix_post_process_native(
            o_recv.view(num_tokens, cp_size, num_heads_tp_cp, value_dim),
            s_recv.view(num_tokens, cp_size, num_heads_tp_cp, 2), 1.0, 1)

    # v1: swap the cp and head axes before sending. Defaults bind the
    # original tensors so the closures do not depend on later rebinding.
    def _o_for_v1(src=partial_o):
        return src.view(num_tokens, cp_size, num_heads_tp_cp,
                        value_dim).transpose(1, 2).contiguous()

    def _stats_for_v1(src=softmax_stats):
        return src.view(num_tokens, cp_size, num_heads_tp_cp,
                        2).transpose(1, 2).contiguous()

    if aux_stream is None or ln_events is None:
        o_send = _o_for_v1()
        s_send = _stats_for_v1()
    else:
        o_send, s_send = maybe_execute_in_parallel(
            _o_for_v1,
            _stats_for_v1,
            ln_events[0],
            ln_events[1],
            aux_stream,
        )
    o_recv, s_recv = helix.alltoall_native(o_send, s_send)
    return torch.ops.trtllm.helix_post_process_native(o_recv, s_recv, 1.0, 2)
def _helix_cp_pad(tensor: torch.Tensor, num_tokens: int,
                  cp_size: int) -> tuple[torch.Tensor, int]:
    """Zero-pad *tensor* along dim 0 so *num_tokens* splits evenly over CP.

    Args:
        tensor: Tensor whose leading dimension holds ``num_tokens`` rows.
            Any rank >= 1 is accepted; only dim 0 is padded.
        num_tokens: Number of valid rows along dim 0.
        cp_size: Number of context-parallel ranks (must be non-zero).

    Returns:
        ``(padded, chunk_size)`` where ``chunk_size = ceil(num_tokens /
        cp_size)`` and ``padded`` has ``chunk_size * cp_size - num_tokens``
        zero rows appended along dim 0 (``tensor`` itself when no padding is
        needed).

    Raises:
        ZeroDivisionError: if ``cp_size`` is 0.
    """
    # Exact integer ceil-division; avoids float rounding of num_tokens/cp_size.
    chunk_size = -(-num_tokens // cp_size)
    pad_rows = chunk_size * cp_size - num_tokens
    if pad_rows > 0:
        # F.pad lists (left, right) pairs starting from the LAST dimension,
        # so dim 0 is the final pair and every other dim gets (0, 0).
        # Building the tuple from the rank keeps this correct for any
        # tensor rank, not just 2-D.
        pad = (0, 0) * (tensor.dim() - 1) + (0, pad_rows)
        tensor = torch.nn.functional.pad(tensor,
                                         pad,
                                         mode="constant",
                                         value=0)
    return tensor, chunk_size
def _helix_cp_allgather_input(hidden_states: torch.Tensor,
                              attn_metadata: AttentionMetadata,
                              mapping: Mapping, layer_idx: int) -> torch.Tensor:
    """Undo the previous layer's reduce-scatter with a CP AllGather.

    Only active with Helix CP + attention DP and for ``layer_idx > 0``; the
    first layer already receives the full input from the embedding. The
    gathered tensor is trimmed back to ``attn_metadata.num_tokens`` rows,
    dropping any padding rows added before the reduce-scatter.
    """
    needs_gather = (mapping.has_cp_helix() and mapping.enable_attention_dp
                    and layer_idx > 0)
    if not needs_gather:
        return hidden_states
    gathered = cp_allgather(hidden_states, mapping, dim=0)
    return gathered[:attn_metadata.num_tokens]
def _helix_cp_output_projection(
    o_proj: Linear,
    attn_output: torch.Tensor,
    attn_metadata: AttentionMetadata,
    all_reduce_params: Optional[AllReduceParams],
    mapping: Mapping,
    mapping_o: Mapping,
    layer_idx: int,
    lora_params: Optional[dict] = None,
) -> torch.Tensor:
    """Apply *o_proj*, reduce-scattering the result under Helix CP + DP.

    With Helix CP and attention DP enabled, o_proj's own all-reduce is
    disabled. The projected output is then zero-padded along dim 0 to a
    multiple of ``mapping.cp_size`` and reduce-scattered over *mapping_o*,
    so that each rank keeps one token chunk for the MLP. Otherwise o_proj
    runs with the caller's *all_reduce_params* (standard AllReduce path).
    """
    if not (mapping.has_cp_helix() and mapping.enable_attention_dp):
        return o_proj(attn_output,
                      all_reduce_params=all_reduce_params,
                      lora_params=lora_params,
                      layer_idx=layer_idx)

    projected = o_proj(
        attn_output,
        all_reduce_params=AllReduceParams(enable_allreduce=False),
        lora_params=lora_params,
        layer_idx=layer_idx)
    projected, _chunk = _helix_cp_pad(projected, attn_metadata.num_tokens,
                                      mapping.cp_size)
    return reducescatter(projected, mapping_o, dim=0)
def maybe_slice_for_helix_cp(tensor: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             mapping_with_cp: Optional[Mapping],
                             layer_idx: int) -> torch.Tensor:
    """Keep only this CP rank's token chunk of *tensor* (first layer only).

    In the first decoder layer the residual comes straight from the embedding
    and has not been reduce-scattered. When Helix CP + attention DP is active
    and ``layer_idx == 0``, the tensor is zero-padded like the attention
    output and sliced to this rank's chunk so the two line up. In every other
    case the tensor is returned unchanged.

    Call it in the decoder layer on the residual *after* attention forward,
    so that the Attention/MLA forward signatures stay unchanged.
    """
    if (mapping_with_cp is None or not mapping_with_cp.has_cp_helix()
            or not mapping_with_cp.enable_attention_dp or layer_idx != 0):
        return tensor
    padded, chunk_size = _helix_cp_pad(tensor, attn_metadata.num_tokens,
                                       mapping_with_cp.cp_size)
    begin = mapping_with_cp.cp_rank * chunk_size
    return padded[begin:begin + chunk_size]
def maybe_allgather_for_helix_cp(
        hidden_states: torch.Tensor, attn_metadata: AttentionMetadata,
        mapping_with_cp: Optional[Mapping]) -> torch.Tensor:
    """Restore the full token count after the last layer's reduce-scatter.

    With Helix CP + attention DP, each decoder layer's reduce-scatter leaves
    every CP rank with only its own chunk of tokens. This AllGathers across
    the CP group and trims to ``attn_metadata.num_tokens``, so the final norm
    and LM head see every token. In any other configuration the input is
    returned unchanged.

    Call it at the end of the model's ``forward()``, after the decoder layer
    loop.
    """
    active = (mapping_with_cp is not None and mapping_with_cp.has_cp_helix()
              and mapping_with_cp.enable_attention_dp)
    if not active:
        return hidden_states
    full = cp_allgather(hidden_states, mapping_with_cp, dim=0)
    return full[:attn_metadata.num_tokens]
class Attention(nn.Module):
def __init__(
    self,
    *,
    hidden_size: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    max_position_embeddings: int,
    bias: bool,
    pos_embd_params: Optional[PositionalEmbeddingParams] = None,
    rope_fusion: Optional[bool] = None,
    layer_idx: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    dense_bias: Optional[bool] = None,
    config: Optional[ModelConfig] = None,
    q_scaling: float = 1.0,
    attention_chunk_size: Optional[int] = None,
    disable_deep_gemm: bool = False,
    attn_output_gate: Optional[bool] = None,
    use_custom_cublas_mm: bool = False,
    reduce_output: bool = True,
    mapping_with_cp: Optional[Mapping] = None,
):
    """
    Initialize the Attention module.

    Args:
        hidden_size (int): The size of the hidden dimension.
        num_attention_heads (int): The total number of attention heads (before TP sharding).
        num_key_value_heads (int): The total number of key/value heads (before TP sharding).
        max_position_embeddings (int): The maximum position embeddings.
        bias (bool): Whether to use bias in the QKV projection (and in the output projection when dense_bias is None).
        pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
        rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip applying unfused RoPE. If None, whether to fuse is decided by the capability of the attention backend. Forced to False for the "rocket" sparse-attention algorithm or when the backend does not support fused RoPE.
        layer_idx (Optional[int]): The layer index.
        dtype (Optional[torch.dtype]): The data type passed to the QKV and output projections.
        dense_bias (Optional[bool]): Whether to use bias in the output projection layer. Defaults to ``bias`` when None.
        config (Optional[ModelConfig]): The model configuration. When given, this layer is also registered (as a weak reference) in ``config.extra_attrs["attn_layers"]``.
        q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
        attention_chunk_size (Optional[int]): See [Chunked Attention] below.
        disable_deep_gemm (bool): Whether to disable the use of DeepGEMM in Linear layers (currently only matters on SM100 + FP8).
        attn_output_gate (Optional[bool]): Whether to use an output gate. When truthy, the fused QKV projection is widened so the q part also carries a per-head gate. None and False are treated identically by this module (no gate).
        use_custom_cublas_mm (bool): Forwarded to the QKV and output projection Linear layers.
        reduce_output (bool): Forwarded to the output projection Linear layer as ``reduce_output``.
        mapping_with_cp (Optional[Mapping]): Override mapping with CP configuration; replaces ``config.mapping`` when given.
    """
    super().__init__()
    self.layer_idx = layer_idx
    self.layer_idx_str = str(layer_idx)
    self.register_to_config = False
    # Whenever a config is provided (regardless of backend), register a weak
    # reference to this layer in config.extra_attrs["attn_layers"] under a
    # unique key, so custom ops can look it up via extract_extra_attrs().
    if config is not None:
        if "attn_layers" not in config.extra_attrs:
            config.extra_attrs["attn_layers"] = {}
        suffix = 0
        # Makes sure there is no duplicate attention layer identifier.
        while self.layer_idx_str in config.extra_attrs["attn_layers"]:
            self.layer_idx_str = str(layer_idx) + f"_{suffix}"
            suffix += 1
        config.extra_attrs["attn_layers"][self.layer_idx_str] = weakref.ref(
            self)
        self.register_to_config = True
    config = config or ModelConfig()
    self.hidden_size = hidden_size
    self.num_heads = num_attention_heads
    # Prefer an explicit integer head_dim from the pretrained config; fall
    # back to hidden_size // num_heads otherwise.
    self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
    if not isinstance(self.head_dim, int):
        self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.max_position_embeddings = max_position_embeddings
    self.pos_embd_params = pos_embd_params
    self.dense_bias = dense_bias
    self.q_scaling = q_scaling
    self.attn_output_gate = attn_output_gate
    if self.attn_output_gate:
        logger.info_once("using attn output gate!", key="attn_output_gate")
    # [Chunked Attention]
    # Chunked attention is applied to context requests only. Chunked attention will be
    # applied when this field is specified and mMaskType == CAUSAL.
    #
    # In chunked attention, we break context requests into chunks of a specified size. Tokens can only
    # attend to tokens in the same chunk. So, for example, if the chunk size is 3, we might have a mask
    # that looks like this:
    #
    # 1 0 0 0 0 0
    # 1 1 0 0 0 0
    # 1 1 1 0 0 0
    # 0 0 0 1 0 0
    # 0 0 0 1 1 0
    # 0 0 0 1 1 1
    self.attention_chunk_size = attention_chunk_size
    if dense_bias is None:
        self.dense_bias = bias
    # Parallelism setup. With attention DP, the TP ranks act as DP ranks for
    # attention, so attention itself runs with tp_size == 1.
    if mapping_with_cp is not None:
        logger.warning_once(
            "[Attention::__init__] Overriding mapping with CP detected.",
            key="attention_init_mapping_with_cp")
        self.mapping = mapping_with_cp
    else:
        self.mapping = config.mapping
    tp_size = self.mapping.tp_size
    pp_size = self.mapping.pp_size
    cp_size = self.mapping.cp_size
    dp_size = 1
    if self.mapping.enable_attention_dp:
        dp_size = tp_size
        tp_size = 1
    if self.mapping.cp_size > 1:
        assert self.mapping.has_cp_helix(
        ), f"CP type must be HELIX for Attention, but got {self.mapping.cp_config['cp_type']}."
    # Mapping for qkv_proj: when DP or CP is active, dp_size is folded into
    # pp_size while cp_size is kept as-is. NOTE(review): presumably so that
    # the Linear layer shards only over the reduced tp_size — confirm.
    if dp_size == 1 and cp_size == 1:
        mapping = self.mapping
    else:
        mapping = Mapping(
            world_size=dp_size * tp_size * pp_size * cp_size,
            tp_size=tp_size,
            pp_size=pp_size * dp_size,
            cp_size=cp_size,
            cp_config=self.mapping.cp_config,
            rank=self.mapping.rank,
            gpus_per_node=self.mapping.gpus_per_node,
            enable_attention_dp=self.mapping.enable_attention_dp,
        )
    self.tp_size = tp_size
    self.cp_size = cp_size
    self.tp_rank = mapping.tp_rank
    assert self.num_heads % (tp_size * cp_size) == 0
    # Per-rank head counts: num_heads after TP; num_heads_tp_cp after TP and CP.
    self.num_heads = self.num_heads // tp_size
    self.num_heads_tp_cp = self.num_heads // cp_size
    # Ceil-division: with fewer KV heads than tp_size, each rank still gets
    # one KV head (presumably replicated across ranks — TODO confirm).
    self.num_key_value_heads = (self.num_key_value_heads + tp_size -
                                1) // tp_size
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_key_value_heads * self.head_dim
    self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
    self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
    # (offset, size) of each part inside the fused QKV output; the q part is
    # twice as wide when the output gate is enabled.
    qkv_shard_indices_mapping = {
        "q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
        "k":
        (self.q_size * (2 if self.attn_output_gate else 1), self.kv_size),
        "v":
        (self.q_size * (2 if self.attn_output_gate else 1) + self.kv_size,
         self.kv_size),
    }
    self.qkv_proj = Linear(
        self.hidden_size,
        tp_size * self.q_size * (2 if self.attn_output_gate else 1) +
        2 * tp_size * self.kv_size,
        bias=bias,
        dtype=dtype,
        mapping=mapping,
        tensor_parallel_mode=TensorParallelMode.COLUMN,
        weights_loading_config=WeightsLoadingConfig(
            weight_mode=WeightMode.FUSED_QKV_LINEAR),
        quant_config=config.get_quant_config(),
        skip_create_weights_in_init=config.skip_create_weights_in_init,
        allreduce_strategy=config.allreduce_strategy,
        force_dynamic_quantization=config.force_dynamic_quantization,
        disable_deep_gemm=disable_deep_gemm,
        use_custom_cublas_mm=use_custom_cublas_mm,
        fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
        use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
    self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                            [self.hidden_size])
    # For Helix CP, combine TP and CP for the output projection so each
    # rank's o_proj input is num_heads_tp_cp * head_dim.
    if dp_size == 1 and cp_size == 1:
        mapping_o = self.mapping
    else:
        mapping_o = Mapping(
            world_size=dp_size * tp_size * pp_size * cp_size,
            tp_size=tp_size * cp_size,
            pp_size=pp_size * dp_size,
            cp_size=1,
            rank=self.mapping.rank,
            gpus_per_node=self.mapping.gpus_per_node,
            enable_attention_dp=self.mapping.enable_attention_dp,
        )
    self.mapping_o = mapping_o
    self.o_proj = Linear(
        tp_size * self.q_size,
        self.hidden_size,
        bias=self.dense_bias,
        dtype=dtype,
        mapping=mapping_o,
        tensor_parallel_mode=TensorParallelMode.ROW,
        quant_config=config.get_quant_config(),
        skip_create_weights_in_init=config.skip_create_weights_in_init,
        lora=self.o_lora,
        reduce_output=reduce_output,
        allreduce_strategy=config.allreduce_strategy,
        force_dynamic_quantization=config.force_dynamic_quantization,
        disable_deep_gemm=disable_deep_gemm,
        use_custom_cublas_mm=use_custom_cublas_mm,
        use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
    self.quant_config = config.get_quant_config()
    self.attn_backend = config.attn_backend
    attn_cls = get_attention_backend(
        self.attn_backend,
        sparse_attn_config=config.sparse_attention_config)
    # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
    # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
    # handles them as a single fused operation.
    self.splitted_qkv_lora = LoraLayer([
        LoraModuleType.ATTENTION_Q, LoraModuleType.ATTENTION_K,
        LoraModuleType.ATTENTION_V
    ], [self.q_size, self.kv_size, self.kv_size])
    self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
                                    [self.q_size + 2 * self.kv_size])
    # Whether to fuse RoPE into the attention OP.
    # If true, RoPE will be applied in self.attn.forward.
    # If false, RoPE will be applied in self.apply_rope.
    self.rope_fusion = rope_fusion
    if config.sparse_attention_config is not None:
        # Log sparse attention configuration once
        algo = config.sparse_attention_config.algorithm
        cfg_dump = config.sparse_attention_config.model_dump(
            exclude_none=True)
        logger.info_once(f"Using sparse attention: {algo} {cfg_dump}",
                         key="sparse_attention_config")
        if config.sparse_attention_config.algorithm == "rocket":
            logger.warning_once("disable rope_fusion for RocketKV.",
                                key="disable_rope_fusion_for_rocketkv")
            self.rope_fusion = False
    if self.rope_fusion and not attn_cls.support_fused_rope():
        logger.warning_once(
            "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion.",
            key="disable_rope_fusion_for_non_supported_backend")
        self.rope_fusion = False
    # If rope_fusion is not specified, enable if the attention backend supports it.
    if self.rope_fusion is None:
        self.rope_fusion = attn_cls.support_fused_rope()
    # An unfused rotary embedding module is only built when RoPE is applied
    # outside the attention op and positional-embedding params are given.
    self.rotary_emb = None
    if not self.rope_fusion and self.pos_embd_params is not None:
        if self.pos_embd_params.type.is_mrope():
            self.rotary_emb = MRotaryEmbedding(
                self.pos_embd_params.rope,
                head_dim=self.head_dim,
                is_neox=self.pos_embd_params.is_neox,
                mrope_section=self.pos_embd_params.mrope_section,
                mrope_interleaved=self.pos_embd_params.mrope_interleaved)
        else:
            self.rotary_emb = RotaryEmbedding(
                self.pos_embd_params.rope,
                head_dim=self.head_dim,
                is_neox=self.pos_embd_params.is_neox,
            )
    # The backend op only receives pos_embd_params when RoPE is fused into it.
    self.attn = create_attention(
        self.attn_backend,
        self.layer_idx,
        self.num_heads,
        self.head_dim,
        self.num_key_value_heads,
        pos_embd_params=self.pos_embd_params if self.rope_fusion else None,
        quant_config=self.quant_config,
        skip_create_weights_in_init=config.skip_create_weights_in_init,
        q_scaling=self.q_scaling,
        attention_chunk_size=self.attention_chunk_size,
        sparse_attention_config=config.sparse_attention_config,
    )
    self.support_fused_qkv = self.attn.support_fused_qkv()
    if not config.skip_create_weights_in_init:
        self.create_weights()
def create_weights(self):
    """Materialize weights and derive ``has_quant_scale`` from ``o_proj``.

    ``self.attn`` owns no weights, but its quantization-dependent state is
    refreshed here because ``quant_config`` may be modified after __init__.
    ``has_quant_scale`` becomes the first truthy quantization flag of
    ``o_proj`` (or the last flag's value if none is truthy).
    """
    self.attn.update_quant_config(self.quant_config)
    self.o_proj.create_weights()

    flag_names = (
        "has_fp8_qdq",
        "has_nvfp4",
        "has_fp8_block_scales",
        "has_fp8_rowwise",
        "has_w4a8_nvfp4_fp8",
    )
    # Mirrors a short-circuiting ``or`` chain: stop at the first truthy flag.
    found = False
    for flag_name in flag_names:
        found = getattr(self.o_proj, flag_name)
        if found:
            break
    self.has_quant_scale = found
def split_qkv(self, q, k=None, v=None):
    """Split a fused QKV tensor into ``(q, k, v)`` along the last dim.

    If either *k* or *v* is provided, the inputs are returned unchanged.
    Otherwise *q* is treated as fused and split into
    ``[q_size, kv_size, kv_size]``.
    """
    if k is not None or v is not None:
        return q, k, v
    sizes = [self.q_size, self.kv_size, self.kv_size]
    q_part, k_part, v_part = q.split(sizes, dim=-1)
    return q_part, k_part, v_part
def convert_qkv(self, q, k, v):
    """Adapt q/k/v to the layout accepted by the attention backend.

    * Fused input (``k`` and ``v`` both None) for a backend without
      fused-QKV support is split into separate q, k, v.
    * Separate q, k, v for a backend with fused-QKV support are concatenated
      along the last dim and returned as ``(qkv, None, None)``.
    * Any other combination is returned unchanged.
    """
    if k is None and v is None:
        if not self.support_fused_qkv:
            q_out, k_out, v_out = self.split_qkv(q)
            return q_out, k_out, v_out
    elif k is not None and v is not None:
        if self.support_fused_qkv:
            return torch.concat([q, k, v], dim=-1), None, None
    return q, k, v
def _use_quantize_output(self):
    """Return whether the attention op should emit quantized output.

    Requires ``has_quant_scale``. It is disabled when the output gate is
    enabled, and when ``o_proj`` carries a non-None AWQ ``pre_quant_scale``.
    """
    awq_scale = getattr(self.o_proj, 'pre_quant_scale', None)
    has_awq_pre_quant_scale = awq_scale is not None
    return (self.has_quant_scale and not self.attn_output_gate
            and not has_awq_pre_quant_scale)
def create_output(self, q: torch.Tensor, attn_metadata: AttentionMetadata,
                  mask_type: str):
    """Ask the backend to create output tensor(s) shaped for *q*.

    Quantized output is requested according to ``_use_quantize_output()``.
    The batch is always treated as a mixed (not generation-only) request.
    """
    quantize_output = self._use_quantize_output()
    return self.attn.create_output(q,
                                   is_quantize_output=quantize_output,
                                   metadata=attn_metadata,
                                   attention_mask=mask_type,
                                   is_gen_only=False)
def _helix_post_process(self, partial_o: torch.Tensor,
                        softmax_stats: torch.Tensor) -> torch.Tensor:
    """Combine this layer's partial attention outputs across CP ranks.

    Thin adapter over the module-level Helix post-process, using
    ``head_dim`` as the value dimension and this layer's mapping/head count.
    """
    return _helix_post_process(
        partial_o,
        softmax_stats,
        mapping=self.mapping,
        num_heads_tp_cp=self.num_heads_tp_cp,
        value_dim=self.head_dim,
    )
def _attn_impl(
    self,
    q: torch.Tensor,
    k: Optional[torch.Tensor],
    v: Optional[torch.Tensor],
    attn_metadata: AttentionMetadata,
    attention_mask: AttentionMask,
    mrope_rotary_cos_sin: Optional[torch.Tensor],
    mrope_position_deltas: Optional[torch.Tensor],
    attention_window_size: Optional[int],
    attention_mask_data: Optional[torch.Tensor],
    output: Optional[torch.Tensor] = None,
    output_sf: Optional[torch.Tensor] = None,
    attention_sinks: Optional[torch.Tensor] = None,
    has_lora: bool = False,
):
    """Run the attention backend on q/k/v and return ``(output, output_sf)``.

    Inputs (and *output*, when given) are truncated to
    ``attn_metadata.num_tokens`` rows first.

    Two paths:
      * Helix CP with generation-only batches (``num_contexts == 0``): the
        backend also fills a float32 ``softmax_stats`` tensor, and the
        partial outputs are combined across CP ranks. ``output_sf`` is None,
        and a pre-allocated *output* is not supported (asserted).
      * Otherwise: quantized output scales come from ``o_proj`` when
        ``_use_quantize_output()`` holds and LoRA is inactive. KV scaling
        factors come from ``qkv_proj`` when the quant mode has an FP4 KV
        cache.

    Returns:
        ``(output, output_sf)``. ``output_sf`` is None unless the backend
        returned a 2-tuple.
    """
    num_tokens = attn_metadata.num_tokens
    q = q[:num_tokens, :]
    if k is not None:
        k = k[:num_tokens, :]
    if v is not None:
        v = v[:num_tokens, :]
    # Pack the optional mRoPE tensors into the dict form the backend expects.
    mrope_config = None
    if mrope_rotary_cos_sin is not None or mrope_position_deltas is not None:
        mrope_config = dict()
        if mrope_rotary_cos_sin is not None:
            mrope_config["mrope_rotary_cos_sin"] = mrope_rotary_cos_sin
        if mrope_position_deltas is not None:
            mrope_config["mrope_position_deltas"] = mrope_position_deltas
    # Helix CP generation path: get partial outputs with softmax stats,
    # then exchange and combine across CP ranks.
    # NOTE: The helix post-process combine step works on unquantized
    # (BF16/FP16) partial outputs and softmax stats from each rank.
    # We intentionally skip passing out_scale/out_scale_sf to FMHA here
    # so it produces BF16 output. After combining, the downstream o_proj
    # linear layer handles quantization (FP8/NVFP4) in its apply() method.
    if self.mapping.has_cp_helix() and attn_metadata.num_contexts == 0:
        assert output is None, (
            "Helix produces BF16 partial outputs which may not match a pre-allocated FP8/NVFP4 buffer for torch.compile inplace output."
        )
        # Two float32 statistics per (token, head), filled by the backend.
        softmax_stats = torch.empty((num_tokens, self.num_heads, 2),
                                    device=q.device,
                                    dtype=torch.float32)
        attn_output = self.attn.forward(
            q,
            k,
            v,
            attn_metadata,
            attention_mask=attention_mask,
            mrope_config=mrope_config,
            attention_window_size=attention_window_size,
            attention_mask_data=attention_mask_data,
            softmax_stats_tensor=softmax_stats,
            attention_sinks=attention_sinks)
        if isinstance(attn_output, tuple):
            attn_output = attn_output[0]
        attn_output = self._helix_post_process(attn_output, softmax_stats)
        return attn_output, None
    out_scale = None
    out_scale_sf = None
    # Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output
    # and keeps attention output in BF16 for better precision when applying pre_quant_scale
    # Also don't set out_scale if LoRA is active - LoRA grouped_gemm doesn't support FP8
    if self._use_quantize_output() and not has_lora:
        out_scale = self.o_proj.inv_input_scale
        out_scale_sf = self.o_proj.input_scale
    kv_scales_sf = None
    kv_scales_sf_inv = None
    if self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp4_kv_cache(
    ):
        kv_scales_sf = self.qkv_proj.kv_scales
        kv_scales_sf_inv = self.qkv_proj.inv_kv_scales
    attn_output = self.attn.forward(
        q,
        k,
        v,
        attn_metadata,
        out_scale=out_scale,
        out_scale_sf=out_scale_sf,
        kv_scales_sf=kv_scales_sf,
        kv_scales_sf_inv=kv_scales_sf_inv,
        attention_mask=attention_mask,
        mrope_config=mrope_config,
        attention_window_size=attention_window_size,
        attention_mask_data=attention_mask_data,
        output=output[:num_tokens, :] if output is not None else None,
        output_sf=output_sf,
        attention_sinks=attention_sinks)
    # A tuple result carries (output, output_sf), e.g. for quantized output.
    if isinstance(attn_output, tuple):
        assert len(
            attn_output
        ) == 2, "attn_output should be a tuple of (output, output_sf)"
        return attn_output[0], attn_output[1]
    return attn_output, None
def forward_impl(
    self,
    q: torch.Tensor,
    k: Optional[torch.Tensor],
    v: Optional[torch.Tensor],
    attn_metadata: AttentionMetadata,
    attention_mask: AttentionMask,
    attention_window_size: Optional[int],
    attention_mask_data: Optional[torch.Tensor],
    mrope_config: Optional[dict],
    attention_sinks: Optional[torch.Tensor] = None,
    has_lora: bool = False,
):
    """Run attention on projected q/k/v and return the (possibly FP4) output.

    While torch-compiling a registered layer on the TRTLLM or FLASHINFER
    backend, output buffers are created through the
    ``trtllm::create_attn_outputs`` custom op and filled in place by
    ``trtllm::attn_custom_op_inplace``. Otherwise ``_attn_impl`` is called
    directly. A scaling-factor tensor, when present, is wrapped together
    with the output in an ``Fp4QuantizedTensor``.
    """
    if mrope_config is None:
        mrope_rotary_cos_sin = None
        mrope_position_deltas = None
    else:
        mrope_rotary_cos_sin = mrope_config.get("mrope_rotary_cos_sin")
        mrope_position_deltas = mrope_config.get("mrope_position_deltas")

    # Currently only TRTLLM and FLASHINFER are torch compile compatible
    # backends; the custom in-place op is only used while compiling.
    use_custom_inplace_op = (self.register_to_config
                             and self.attn_backend in ("TRTLLM", "FLASHINFER")
                             and is_torch_compiling())

    if not use_custom_inplace_op:
        output, output_sf = self._attn_impl(q,
                                            k,
                                            v,
                                            attn_metadata,
                                            attention_mask,
                                            mrope_rotary_cos_sin,
                                            mrope_position_deltas,
                                            attention_window_size,
                                            attention_mask_data,
                                            attention_sinks=attention_sinks,
                                            has_lora=has_lora)
    else:
        buffers = create_attn_outputs(q, attention_mask, self.layer_idx_str)
        assert len(buffers) in (1, 2)
        output = buffers[0]
        output_sf = buffers[1] if len(buffers) == 2 else None
        attn_custom_op_inplace(
            q,
            k,
            v,
            attention_mask,
            mrope_rotary_cos_sin,
            mrope_position_deltas,
            attention_window_size,
            attention_mask_data,
            attention_sinks,
            self.layer_idx_str,
            output,
            output_sf,
        )

    if output_sf is None:
        return output
    return Fp4QuantizedTensor(output, output_sf)
def forward(
    self,
    position_ids: Optional[torch.IntTensor],
    hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
    attn_metadata: AttentionMetadata,
    attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
    mrope_config: Optional[dict] = None,
    all_reduce_params: Optional[AllReduceParams] = None,
    lora_params: Optional[dict] = None,
    attention_window_size: Optional[int] = None,
    attention_mask_data: Optional[torch.Tensor] = None,
    attention_sinks: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Forward pass for the Attention module.

    The pass runs these steps in order:
      1. Gather the input with ``_helix_cp_allgather_input``.
      2. Apply the fused QKV projection, plus any LoRA deltas.
      3. Optionally split off an output gate.
      4. Apply RoPE and convert q/k/v into the layout the backend expects.
      5. Run the attention kernel.
      6. Apply the sigmoid output gate if one was split off.
      7. Apply the output projection.

    Args:
        position_ids (Optional[torch.IntTensor]): Position IDs used for RoPE.
        hidden_states (Union[torch.Tensor, Fp4QuantizedTensor]): The input
            hidden states.
        attn_metadata (AttentionMetadata): The attention metadata.
        attention_mask (AttentionMask): The attention mask type.
        mrope_config (Optional[dict]): The MROPE configuration.
        all_reduce_params (Optional[AllReduceParams]): Parameters for the
            all-reduce in the output projection.
        lora_params (Optional[dict]): LoRA parameters. A falsy value
            disables LoRA.
        attention_window_size (Optional[int]): The attention window size.
        attention_mask_data (Optional[torch.Tensor]): The attention mask data.
        attention_sinks (Optional[torch.Tensor]): Attention sinks. These are
            only accepted with the TRTLLM backend.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        torch.Tensor: The output of the output projection.
    """
    hidden_states = _helix_cp_allgather_input(hidden_states, attn_metadata,
                                              self.mapping, self.layer_idx)
    qkv = self.qkv_proj(hidden_states)

    use_lora = bool(lora_params)
    if use_lora:
        # Each adapter either returns a delta to add onto qkv or None.
        for lora_module in (self.splitted_qkv_lora, self.fused_qkv_lora):
            lora_delta = lora_module(hidden_states, lora_params,
                                     self.layer_idx)
            if lora_delta is not None:
                qkv = qkv + lora_delta

    gate = None
    if not self.attn_output_gate:
        # Keep q/k/v fused; k and v travel as None.
        q, k, v = qkv, None, None
    else:
        # The query slice is twice as wide and holds, per head, the query
        # followed by its gate.
        split_sizes = [self.q_size * 2, self.kv_size, self.kv_size]
        q_gate, k, v = qkv.split(split_sizes, dim=-1)
        leading_dims = q_gate.shape[:-1]
        per_head = q_gate.view(*leading_dims, self.num_heads, -1)
        q_part, gate_part = torch.chunk(per_head, 2, dim=-1)
        q = q_part.reshape(*leading_dims, -1)
        gate = gate_part.reshape(*leading_dims, -1)

    q, k, v = self.apply_rope(q, k, v, position_ids)
    q, k, v = self.convert_qkv(q, k, v)

    assert attention_sinks is None or self.attn_backend == "TRTLLM", "Attention sinks are only supported for TRTLLM backend."

    attn_output = self.forward_impl(q,
                                    k,
                                    v,
                                    attn_metadata,
                                    attention_mask,
                                    attention_window_size,
                                    attention_mask_data,
                                    mrope_config=mrope_config,
                                    attention_sinks=attention_sinks,
                                    has_lora=use_lora)

    if gate is not None:
        attn_output = attn_output * torch.sigmoid(gate)

    return _helix_cp_output_projection(self.o_proj, attn_output,
                                       attn_metadata, all_reduce_params,
                                       self.mapping, self.mapping_o,
                                       self.layer_idx, lora_params)
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
               v: Optional[torch.Tensor], position_ids: torch.Tensor):
    """
    Apply rotary position embedding (RoPE) to the query and key.

    q, k, v arrive either fused (q holds concat(q, k, v), k and v are None)
    or already split (none of them is None). ``convert_qkv`` runs later to
    bring the triple into the layout the attention backend needs.
    Subclasses may override this method to add extras such as q_norm/k_norm.

    Args:
        q (torch.Tensor): The query tensor, or the fused qkv tensor.
        k (Optional[torch.Tensor]): The key tensor.
        v (Optional[torch.Tensor]): The value tensor.
        position_ids (torch.Tensor): Per-token position IDs for RoPE.

    Returns:
        tuple: ``(q, k, v)``. They are returned unchanged when RoPE is
        fused into the attention op or ``position_ids`` is None.
    """
    # RoPE is fused into the attention op, or there are no positions to
    # rotate by, so leave the inputs as they are.
    if self.rope_fusion or position_ids is None:
        return q, k, v
    q, k, v = self.split_qkv(q, k, v)
    q, k = self.rotary_emb(position_ids, [q, k])
    return q, k, v
def apply_qk_norm(self, q, k):
    """Placeholder hook for QK normalization.

    Subclasses that need q/k normalization must override this method. The
    base implementation always raises ``NotImplementedError`` naming the
    concrete class.
    """
    class_name = type(self).__name__
    raise NotImplementedError(
        f"QK norm is not implemented for {class_name}. "
        "Please override the `apply_qk_norm` method in the subclass.")
@torch.library.custom_op("trtllm::mla_custom_op_inplace",
                         mutates_args=("output", ))
def mla_custom_op_inplace(
    hidden_states: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    layer_idx: str,
    output: torch.Tensor,
    latent_cache_gen: Optional[torch.Tensor],
) -> None:
    """Custom-op entry point that runs an MLA layer and writes into ``output``.

    The layer and its metadata are looked up by the string ``layer_idx`` via
    ``extract_extra_attrs(..., "mla")``. Then the layer's ``forward_impl``
    runs with ``output`` as its destination. ``output`` is declared in
    ``mutates_args``, so the op is registered as mutating it in place.
    """
    attn_metadata, layer = extract_extra_attrs(layer_idx, "mla")
    layer.forward_impl(
        position_ids,
        hidden_states,
        attn_metadata,
        output=output,
        latent_cache_gen=latent_cache_gen,
    )
def _fp8_block_scaling_bmm_via_staging(
    mat1_fp8: torch.Tensor,
    mat1_scale: torch.Tensor,
    mat2_fp8: torch.Tensor,
    mat2_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Run the TRT-LLM FP8 block-scaling BMM into a fresh buffer, then copy into ``out``."""
    # NOTE(review): presumably the staging buffer exists because the kernel
    # needs a freshly allocated (contiguous) output while ``out`` may be a
    # strided view -- confirm before removing the extra copy.
    staging = out.new_empty(out.shape, dtype=out.dtype, device=out.device)
    torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
                                               mat1_scale, mat2_scale,
                                               staging)
    out.copy_(staging)


def fp8_block_scaling_bmm_out(
    mat1: torch.Tensor,
    mat2_fp8: torch.Tensor,
    mat2_scale: torch.Tensor,
    out: torch.Tensor,
    mat2_dequant: Optional[torch.Tensor] = None,
    use_cute_dsl_blockscaling_bmm: bool = False,
) -> None:
    """Batched matmul of ``mat1`` with an FP8 block-scaled ``mat2``, written into ``out``.

    The kernel is chosen by the SM version reported by ``get_sm_version()``:
      * SM89 / SM90: ``mat1`` is quantized with
        ``fp8_batched_quantize_1x128_permute102``. The TRT-LLM FP8
        block-scaling BMM then runs into a staging buffer, which is copied
        into ``out``.
      * SM120: as above, except ``mat1`` is quantized with
        ``fp8_utils.per_token_quant_and_transform(need_permute102=True)``.
      * SM100 family (``is_sm_100f``): if ``use_cute_dsl_blockscaling_bmm``
        is set, ``mat1`` is quantized as on SM90 and the CuTe DSL kernel
        writes directly into ``out``. Otherwise ``torch.bmm`` runs on
        ``mat1.transpose(0, 1)`` and ``mat2_dequant.transpose(1, 2)``.
        NOTE(review): the transpose suggests ``mat1`` carries its batch in
        dim 1 -- confirm.

    Args:
        mat1: Left operand in high precision.
        mat2_fp8: FP8 right operand (unused on the ``torch.bmm`` path).
        mat2_scale: Block scales for ``mat2_fp8``.
        out: Destination tensor; written in place.
        mat2_dequant: Dequantized ``mat2``. Required on the SM100 path when
            ``use_cute_dsl_blockscaling_bmm`` is False.
        use_cute_dsl_blockscaling_bmm: Selects the CuTe DSL kernel on SM100.

    Returns:
        None. The result is written into ``out``.

    Raises:
        ValueError: On the SM100 ``torch.bmm`` path when ``mat2_dequant``
            is None.
        NotImplementedError: For SM versions not listed above.
    """
    sm_version = get_sm_version()
    if sm_version in (89, 90):
        mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
            mat1)
        _fp8_block_scaling_bmm_via_staging(mat1_fp8, mat1_scale, mat2_fp8,
                                           mat2_scale, out)
    elif sm_version == 120:
        mat1_fp8, mat1_scale = fp8_utils.per_token_quant_and_transform(
            mat1, need_permute102=True)
        _fp8_block_scaling_bmm_via_staging(mat1_fp8, mat1_scale, mat2_fp8,
                                           mat2_scale, out)
    elif is_sm_100f(sm_version):
        if use_cute_dsl_blockscaling_bmm:
            mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                mat1)
            torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8,
                                                        mat1_scale, mat2_scale,
                                                        out)
        else:
            # Fail clearly instead of with an AttributeError on None.
            if mat2_dequant is None:
                raise ValueError(
                    "mat2_dequant is required for the torch.bmm path on "
                    f"SM{sm_version} when use_cute_dsl_blockscaling_bmm is False"
                )
            torch.bmm(mat1.transpose(0, 1),
                      mat2_dequant.transpose(1, 2),
                      out=out)
    else:
        raise NotImplementedError(f"SM{sm_version} is not supported")
class MLA(nn.Module):
def __init__(
self,
*,
hidden_size: int,