-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathattention.py
More file actions
1721 lines (1521 loc) · 67.1 KB
/
attention.py
File metadata and controls
1721 lines (1521 loc) · 67.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from __future__ import annotations
import copy
import inspect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Protocol, Tuple, Union
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.jit import jit_fuser
from megatron.core.models.common.embeddings.rope_utils import (
apply_rotary_pos_emb,
apply_rotary_pos_emb_with_cos_sin,
)
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.torch_norm import LayerNormBuilder
from megatron.core.typed_torch import apply_module, not_none
from megatron.core.utils import (
deprecate_inference_params,
divide,
get_pg_rank,
get_pg_size,
is_fa_min_version,
is_te_min_version,
is_using_quantization_scales,
nvtx_range_pop,
nvtx_range_push,
)
from ..models.common.embeddings.yarn_rotary_pos_embedding import (
_yarn_get_concentration_factor_from_config,
)
from .enums import AttnMaskType, CudaGraphScope
from .transformer_config import TransformerConfig
try:
    from einops import rearrange
except ImportError:
    rearrange = None

# FlashAttention-3: try the "flash_attn_3" package first, then the
# "flashattn_hopper" package. HAVE_FA3 records whether either import succeeded.
try:
    from flash_attn_3.flash_attn_interface import _flash_attn_forward
    from flash_attn_3.flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn3_with_kvcache,
    )

    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False

if not HAVE_FA3:
    try:
        from flashattn_hopper.flash_attn_interface import _flash_attn_forward
        from flashattn_hopper.flash_attn_interface import (
            flash_attn_with_kvcache as flash_attn3_with_kvcache,
        )

        HAVE_FA3 = True
    except ImportError:
        pass

# FlashMLA kernels; left as None when the package is unavailable.
try:
    from flash_mla import flash_mla_with_kvcache, get_mla_metadata

    HAVE_FMLA = True
except ImportError:
    flash_mla_with_kvcache = None
    get_mla_metadata = None
    HAVE_FMLA = False

from megatron.core.transformer.transformer_config import MLATransformerConfig

# FlashAttention-2 kernels. The handler stays broad (best-effort import: any
# failure leaves the kernels as None) but catches Exception rather than using a
# bare ``except:``, so KeyboardInterrupt and SystemExit are no longer swallowed.
try:
    from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
except Exception:  # pylint: disable=broad-except
    flash_attn_varlen_func = None
    flash_attn_with_kvcache = None

try:
    import transformer_engine  # pylint: disable=unused-import

    HAVE_TE = True
    from megatron.core.extensions.transformer_engine import (
        SplitAlongDim,
        TELinear,
        set_save_original_input,
    )
except ImportError:
    HAVE_TE = False
    SplitAlongDim, TELinear, set_save_original_input = None, None, None

try:
    from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb

    HAVE_FUSED_QKV_ROPE = True
except ImportError:
    HAVE_FUSED_QKV_ROPE = False

# Paged KV-cache block size (in tokens) that the FlashMLA decode path requires;
# other block sizes fall back to the non-FlashMLA code paths.
FMLA_REQUIRED_BLOCK_SIZE = 64
class LinearQkv(Protocol):
    """Structural interface that a fused query/key/value projection module satisfies."""

    def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
        """Project ``input``; returns a pair of (output tensor, auxiliary object)."""

    def backward_dw(self) -> None:
        """Run this module's backward pass (the ``dw`` name suggests weight gradients; confirm)."""
class LinearQkvBuilder(Protocol):
    """Structural interface for factories that construct a ``LinearQkv`` module."""

    def __call__(
        self,
        input_size: int,
        output_size: int,
        /,
        *,
        config: TransformerConfig,
        init_method: Callable[[torch.Tensor], None],
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: str,
        tp_group: torch.distributed.ProcessGroup | None = None,
    ) -> LinearQkv:
        """Build a QKV projection mapping ``input_size`` features to ``output_size``."""
class LinearLayer(Protocol):
    """Structural interface shared by the separate ``linear_q`` and ``linear_kv`` modules."""

    def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
        """Project ``input``; returns a pair of (output tensor, auxiliary object)."""
class LinearLayerBuilder(Protocol):
    """Structural interface for factories that construct ``linear_q``/``linear_kv`` modules."""

    def __call__(
        self,
        input_size: int,
        output_size: int,
        /,
        *,
        config: TransformerConfig,
        init_method: Callable[[torch.Tensor], None],
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
    ) -> LinearLayer:
        """Build a projection mapping ``input_size`` features to ``output_size``."""
class CoreAttention(Protocol):
    """Structural interface of the module that computes dot-product attention."""

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Optional[Tensor],
        /,
        *,
        attn_mask_type: AttnMaskType,
        attention_bias: Optional[Tensor],
        packed_seq_params: Optional[PackedSeqParams],
    ) -> Tensor:
        """Attend ``query`` over ``key``/``value`` and return the attention output."""
class CoreAttentionBuilder(Protocol):
    """Structural interface for factories that construct a ``CoreAttention`` module."""

    def __call__(
        self,
        *,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        cp_comm_type: Optional[str],
        softmax_scale: Optional[float],
        pg_collection: Optional[ProcessGroupCollection],
    ) -> CoreAttention:
        """Build the core attention module; every argument is keyword-only."""
@dataclass
class SelfAttentionSubmodules:
    """
    Configuration class for specifying the submodules of a self-attention.
    """

    # Builder for the fused query/key/value projection.
    linear_qkv: LinearQkvBuilder
    # Builder for the core (dot-product) attention module.
    core_attention: CoreAttentionBuilder
    # Spec or class for the output projection, passed to ``build_module``.
    # It defaults to None, so the annotation admits None.
    linear_proj: ModuleSpec | type | None = None
    # Optional normalization builders for queries and keys.
    q_layernorm: LayerNormBuilder | None = None
    k_layernorm: LayerNormBuilder | None = None
@dataclass
class CrossAttentionSubmodules:
    """
    Configuration class for specifying the submodules of a cross-attention.
    """

    # Builder for the query projection.
    linear_q: LinearLayerBuilder
    # Builder for the key/value projection.
    linear_kv: LinearLayerBuilder
    # Builder for the core (dot-product) attention module.
    core_attention: CoreAttentionBuilder
    # Spec or class for the output projection, passed to ``build_module``.
    # It defaults to None, so the annotation admits None.
    linear_proj: ModuleSpec | type | None = None
class Attention(MegatronModule, ABC):
"""Attention layer abstract class.
This layer only contains common modules required for the "self attn" and
"cross attn" specializations.
"""
def __init__(
    self,
    config: TransformerConfig,
    submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    cp_comm_type: str | None = None,
    pg_collection: ProcessGroupCollection | None = None,
    pp_layer_offset: Optional[int] = None,
):
    """Build the state and submodules shared by self- and cross-attention.

    Args:
        config: Transformer configuration. ``kv_channels`` and
            ``num_query_groups`` must be set (asserted below).
        submodules: Builders for the attention submodules. ``core_attention``
            and ``linear_proj`` are constructed here.
        layer_number: Index of this layer; forwarded to the core attention.
        attn_mask_type: Default attention mask type.
        attention_type: Attention flavour string forwarded to the core attention.
        cp_comm_type: Context-parallel communication type forwarded to the core
            attention.
        pg_collection: Process groups to use; must expose ``tp`` and ``cp``. When
            omitted, ``tp`` and ``cp`` are obtained via
            ``ProcessGroupCollection.use_mpu_process_groups``.
        pp_layer_offset: Optional explicit pipeline layer offset, stored for
            inference-time KV-cache indexing.
    """
    super().__init__(config=config)

    self.config = config
    self.layer_number = layer_number
    self._pp_layer_offset = pp_layer_offset
    self.attn_mask_type = attn_mask_type
    self.attention_type = attention_type
    self.batch_invariant_mode = config.batch_invariant_mode

    assert self.config.kv_channels is not None
    assert self.config.num_query_groups is not None

    # For normal attention without groups, num_query_groups == num_attention_heads,
    # so these two will be the same
    self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
    self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

    if pg_collection is None:
        pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'cp'])
    else:
        assert hasattr(
            pg_collection, 'tp'
        ), "Attention pg_collection must have tp process group"
        assert hasattr(
            pg_collection, 'cp'
        ), "Attention pg_collection must have cp process group"
    self.pg_collection = pg_collection
    self.tp_group = pg_collection.tp

    # Per attention head and per partition values
    world_size = get_pg_size(self.pg_collection.tp)
    self.hidden_size_per_attention_head = divide(
        self.query_projection_size, self.config.num_attention_heads
    )
    if self.config.num_query_groups < world_size:
        # When num_kv_heads < tp_size, each TP rank (post AG) initially produces
        # activations for 1 kv_head and (num_q_heads / num_kv_heads) q_heads.
        # We then pull out the appropriate (num_q_heads / tp_size) q_heads.
        self.num_query_groups_per_partition = 1
        self.num_attention_heads_per_partition = divide(
            self.config.num_attention_heads, self.config.num_query_groups
        )
    else:
        # When num_kv_heads >= tp_size, each TP rank produces activations for
        # (num_kv_heads / tp_size) kv_heads and (num_q_heads / tp_size) q_heads.
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
        self.num_attention_heads_per_partition = divide(
            self.config.num_attention_heads, world_size
        )
    self.world_size = world_size

    # To support both CUDA Graphs and key value with different hidden size
    self.key_hidden_size = self.hidden_size_per_attention_head
    self.val_hidden_size = self.hidden_size_per_attention_head

    if self.config.num_query_groups < world_size:
        # TE raises an assertion error when the number of KV heads
        # (num_query_groups) is not divisible by the TP size. With fewer KV groups
        # than TP ranks, give the core attention a deep-copied config whose
        # num_query_groups equals the TP world size; self.config is left untouched.
        # TODO(rwaleffe/dnarayanan): Clean this up eventually.
        tmp_config = copy.deepcopy(self.config)
        tmp_config.num_query_groups = world_size
    else:
        tmp_config = self.config
    self.core_attention = submodules.core_attention(
        config=tmp_config,
        layer_number=self.layer_number,
        attn_mask_type=self.attn_mask_type,
        attention_type=self.attention_type,
        cp_comm_type=cp_comm_type,
        softmax_scale=self.config.softmax_scale,
        pg_collection=self.pg_collection,
    )

    # True when selective recomputation is configured and "core_attn" is among the
    # recomputed modules (presumably consulted by forward; not visible here).
    self.checkpoint_core_attention = (
        self.config.recompute_granularity == 'selective'
        and "core_attn" in self.config.recompute_modules
    )

    # Fine-grained activation offloading flags, one per offloadable stage
    # ("qkv_linear", "core_attn", "attn_proj" in config.offload_modules).
    self.offload_qkv_linear = (
        self.config.fine_grained_activation_offloading
        and "qkv_linear" in self.config.offload_modules
    )
    self.offload_core_attention = (
        self.config.fine_grained_activation_offloading
        and "core_attn" in self.config.offload_modules
    )
    self.offload_attn_proj = (
        self.config.fine_grained_activation_offloading
        and "attn_proj" in self.config.offload_modules
    )

    # Output.
    self.linear_proj = build_module(
        submodules.linear_proj,
        self.query_projection_size,
        self.config.hidden_size,
        config=self.config,
        init_method=self.config.output_layer_init_method,
        bias=self.config.add_bias_linear,
        input_is_parallel=True,
        skip_bias_add=True,
        is_expert=False,
        tp_comm_buffer_name='proj',
        tp_group=self.pg_collection.tp,
    )

    if (
        HAVE_TE
        and isinstance(self.linear_proj, TELinear)
        and (
            (
                self.config.fp8
                and self.config.fp8_recipe != 'delayed'
                and is_te_min_version("2.6.0dev0")
            )
            or (self.config.fp4 and is_te_min_version("2.7.0.dev0"))
        )
    ):
        # For fp8/fp4 training, the output of the fused core_attn is saved by itself, and
        # linear_proj also saves the quantized tensor of this output. Here we set the
        # linear_proj to save the original input tensors to avoid the extra memory usage of
        # the quantized tensor.
        set_save_original_input(self.linear_proj)
def _checkpointed_attention_forward(
    self,
    query,
    key,
    value,
    attention_mask,
    rotary_pos_emb=None,
    attn_mask_type=None,
    attention_bias=None,
    packed_seq_params=None,
):
    """Run the core attention under ``tensor_parallel.checkpoint`` (selective recompute).

    ``attn_mask_type`` defaults to ``self.attn_mask_type``. It travels through
    the checkpoint as a one-element int tensor and is decoded back into an
    ``AttnMaskType`` inside the recomputed function. ``attention_bias`` and
    ``packed_seq_params`` are captured by closure instead of passed through the
    checkpoint.
    """
    if attn_mask_type is None:
        attn_mask_type = self.attn_mask_type
    encoded_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)

    def _recompute_core_attention(*ckpt_args):
        q_in, k_in, v_in, mask_in = ckpt_args[:4]
        # ckpt_args[4] is rotary_pos_emb, which the core attention call does not use.
        decoded_mask_type = AttnMaskType(ckpt_args[5].item())
        return apply_module(self.core_attention)(
            q_in,
            k_in,
            v_in,
            mask_in,
            attn_mask_type=decoded_mask_type,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
        )

    return tensor_parallel.checkpoint(
        _recompute_core_attention,
        False,
        query,
        key,
        value,
        attention_mask,
        rotary_pos_emb,
        encoded_mask_type,
    )
def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype):
    """Return an uninitialized KV-cache buffer on the current CUDA device.

    Shape: ``(inference_max_sequence_length, batch_size,
    num_query_groups_per_partition, dim)``; dtype as given.
    """
    cache_shape = (
        inference_max_sequence_length,
        batch_size,
        self.num_query_groups_per_partition,
        dim,
    )
    return torch.empty(cache_shape, dtype=dtype, device=torch.cuda.current_device())
def _get_pp_layer_offset_for_inference(self):
    """Return the pipeline-parallel layer offset to use during inference.

    An explicitly supplied ``pp_layer_offset`` (e.g. from MambaBlock for hybrid
    models using --hybrid-layer-pattern with fVPP) is returned as-is. Otherwise
    the offset comes from ``get_transformer_layer_offset``, which assumes layers
    are split uniformly across pipeline stages. Virtual pipeline parallelism is
    rejected in that fallback.
    """
    explicit_offset = self._pp_layer_offset
    if explicit_offset is not None:
        return explicit_offset

    assert (
        self.config.virtual_pipeline_model_parallel_size is None
    ), "Virtual pipeline parallelism is not supported for inference"

    # Deferred import to avoid a circular import at module load time.
    from megatron.core.transformer.transformer_layer import get_transformer_layer_offset

    pipeline_rank = get_pg_rank(self.pg_collection.pp)
    return get_transformer_layer_offset(self.config, vp_stage=None, pp_rank=pipeline_rank)
def _adjust_key_value_for_inference(
    self,
    inference_context: BaseInferenceContext,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    rotary_pos_emb: Tensor,
    rotary_pos_cos: Optional[Tensor] = None,
    rotary_pos_sin: Optional[Tensor] = None,
    rotary_pos_cos_sin: Optional[Tensor] = None,
    sequence_len_offset: Optional[int] = None,
    *,
    inference_params: Optional[BaseInferenceContext] = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, AttnMaskType, Tensor]:
    """
    Saves the generated key and value tensors to the end of the buffers in inference_context.
    Returns the full size keys and values from the provided inference_context, as well as
    adjusted rotary_pos_emb.

    Static batching: per-layer KV buffers are allocated on first use, and the new
    keys/values are written at the current sequence/batch offsets. The returned
    key/value are views over the filled prefix of those buffers.

    Dynamic batching: rotary embeddings are applied to the key (or fused q/k RoPE is
    used) before the key/value are appended to the paged cache. The returned
    key/value come from ``inference_context.key_value_cache`` together with a
    block table.

    Args:
        query (Tensor): Query tensor.
        key (Tensor): Key tensor.
        value (Tensor): Value tensor.
        rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary
            embedding tensor(s).
        rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.
        rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.
        rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine.
            Currently used exclusively for inference with dynamic batching and flashinfer RoPE.
        sequence_len_offset (Optional[int]): Sequence length offset used for
            inference CUDA graphs. Accepted for interface compatibility; this
            method does not read it.

    Return:
        Tuple of: query, key, value, rotary_pos_emb, attn_mask_type, block_table.
        ``block_table`` is None unless dynamic batching is used.
    """
    inference_context = deprecate_inference_params(inference_context, inference_params)

    attn_mask_type = self.attn_mask_type
    if inference_context is None:
        return query, key, value, rotary_pos_emb, attn_mask_type, None

    # =================================================
    # Pre-allocate memory for key-values for inference.
    # =================================================
    if inference_context.is_static_batching():
        if self.layer_number not in inference_context.key_value_memory_dict:
            inf_max_seq_length = inference_context.max_sequence_length
            inf_max_batch_size = inference_context.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, self.key_hidden_size, key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, self.val_hidden_size, value.dtype
            )
            inference_context.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
        else:
            # Get the pre-allocated buffers for this layer
            inference_key_memory, inference_value_memory = (
                inference_context.key_value_memory_dict[self.layer_number]
            )

    if (
        not inference_context.is_static_batching() or inference_context.sequence_len_offset > 0
    ) and (not self.training or not is_te_min_version("2.2.0")):
        # This should mean that we are past the prompt forward_step
        # and so we need to turn off masking
        # Note: in ModelOpt, we may use inference_context for speculative decoding
        # in training. In that case, we do not want to turn off masking as we need
        # customized attention mask for speculative decoding.
        attn_mask_type = AttnMaskType.no_mask

    if inference_context.is_static_batching():
        batch_start = inference_context.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_context.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0), (
            "Current sequence length is longer than expected maximum sequence length! "
            "Increase inference_max_seq_length."
        )

    if self.config.flash_decode:
        rotary_pos_cos_q = None
        rotary_pos_sin_q = None
        rotary_pos_cos_k = None
        rotary_pos_sin_k = None

        assert inference_context.is_static_batching()
        if (
            inference_context.sequence_len_offset > 0 and rotary_pos_cos is not None
        ):  # Decode phase, not prefill
            rotary_pos_cos_q = rotary_pos_cos[sequence_end - 1 : sequence_end]
            rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end]
            rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end]
            rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end]
        elif rotary_pos_cos is not None:  # Prefill
            rotary_pos_cos_q = rotary_pos_cos[:sequence_end]
            rotary_pos_sin_q = rotary_pos_sin[:sequence_end]
            rotary_pos_cos_k = rotary_pos_cos[:sequence_end]
            rotary_pos_sin_k = rotary_pos_sin[:sequence_end]

        # Flash Decoding assumes that the keys stored in the KV Cache already have RoPE applied.
        # Apply RoPE before we store the keys to make it compatible with flash decoding kernel
        if rotary_pos_sin_q is not None and rotary_pos_sin_k is not None:
            key = apply_rotary_pos_emb_with_cos_sin(
                key,
                rotary_pos_cos_k,
                rotary_pos_sin_k,
                rotary_interleaved=self.config.rotary_interleaved,
            )
            query = apply_rotary_pos_emb_with_cos_sin(
                query,
                rotary_pos_cos_q,
                rotary_pos_sin_q,
                rotary_interleaved=self.config.rotary_interleaved,
            )

    # Adjust rotary embeddings. Only static batching slices them to the current
    # window; dynamic batching passes them through unchanged.
    if rotary_pos_emb is not None:
        q_pos_emb, k_pos_emb = rotary_pos_emb
        if inference_context.is_static_batching():
            q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
            k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
        rotary_pos_emb = (q_pos_emb, k_pos_emb)

    block_table = None
    if inference_context.is_static_batching():
        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
    else:
        pp_layer_offset = self._get_pp_layer_offset_for_inference()
        # Apply rotary embeddings before appending KV cache.
        if inference_context.use_flashinfer_fused_rope and (rotary_pos_cos_sin is not None):
            query, key = inference_context.apply_fused_qk_rotary_emb(
                query, key, rotary_pos_cos_sin, self.config
            )
        elif rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            key = inference_context.apply_rotary_emb_key(
                key, k_pos_emb, self.config, self.pg_collection.cp
            )
            rotary_pos_emb = (q_pos_emb, None)  # key rotary emb has been applied

        # Append key/value data tensors to cache.
        inference_context.append_key_value_cache(
            self.layer_number - pp_layer_offset, key, value
        )

        _, max_seqlen_q = inference_context.cu_query_lengths()
        # Read key/value *pointer* tensors from cache.
        key, value, block_table = inference_context.key_value_cache(
            self.layer_number - pp_layer_offset
        )
        block_size_tokens = key.size(1)
        # Do not use absorption when block size doesn't match what's expected by FlashMLA.
        if getattr(self.config, "cache_mla_latents", None) and (
            max_seqlen_q > 1 or block_size_tokens != FMLA_REQUIRED_BLOCK_SIZE
        ):
            # Doing unabsorbed MLA Attention with cached mla latents (prefill/mixed mode)
            kv_cache = key
            # Uncompress the KV cache for prefill/mixed mode
            key, value = self.uncompress_kv_from_cache(kv_cache)

    return query, key, value, rotary_pos_emb, attn_mask_type, block_table
@abstractmethod
def get_query_key_value_tensors(
    self,
    hidden_states: Tensor,
    key_value_states: Tensor | None,
    output_gate: bool = False,
    split_qkv: bool = True,
) -> (
    tuple[Tensor, Tensor, Tensor, Tensor]
    | tuple[Tensor, Tensor, Tensor]
    | tuple[Tensor, list[int]]
):
    """Compute the query, key and value tensors for this attention variant.

    Abstract: concrete subclasses (self-attention vs. cross-attention) decide
    how ``hidden_states`` and ``key_value_states`` map to q/k/v. Per the return
    annotation, the result is a 4- or 3-tensor tuple, or a ``(Tensor, list[int])``
    pair. Which form applies presumably depends on ``output_gate``/``split_qkv``;
    confirm against the subclasses.
    """
def flash_decode(
    self,
    sequence_len_offset: Tensor,
    query_layer: Tensor,
    key_layer: Tensor,
    value_layer: Tensor,
    inference_key_memory: Tensor,
    inference_value_memory: Tensor,
    rotary_cos: Tensor,
    rotary_sin: Tensor,
    rotary_interleaved: bool = False,
) -> Tensor:
    """
    The flash decoding kernel will do the following in a single execution:
    1. Compute RoPE embedding with precomputed cos & sin tensors
    2. Update the KV Cache
    3. Performs the flash attention operation

    Args:
        sequence_len_offset: Passed to the kernel as ``cache_seqlens``.
        query_layer, key_layer, value_layer: New q/k/v. Dims 0 and 1 are swapped
            before the call (presumably [seq, batch, ...] -> [batch, seq, ...];
            confirm with callers).
        inference_key_memory, inference_value_memory: KV cache buffers, permuted
            the same way. ``permute`` returns views, so the kernel's cache update
            writes through to these buffers.
        rotary_cos, rotary_sin: Optional RoPE tables, cast to the query dtype.
        rotary_interleaved: Forwarded to the kernel.

    Returns:
        The attention output tensor. ``flash_attn_with_kvcache`` returns only the
        output because ``return_softmax_lse`` is not requested here.
    """
    assert flash_attn_with_kvcache is not None, (
        "Flash Decoding requires the flash_attn_with_kvcache kernel, "
        "available in the flash-attn package."
    )
    q = query_layer.permute(1, 0, 2, 3)
    k = key_layer.permute(1, 0, 2, 3)
    v = value_layer.permute(1, 0, 2, 3)
    k_cache = inference_key_memory.permute(1, 0, 2, 3)
    v_cache = inference_value_memory.permute(1, 0, 2, 3)

    # The kernel expects the rotary tables in the same dtype as the queries.
    if rotary_cos is not None:
        rotary_cos = rotary_cos.to(query_layer.dtype)
    if rotary_sin is not None:
        rotary_sin = rotary_sin.to(query_layer.dtype)

    out = flash_attn_with_kvcache(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        k=k,
        v=v,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        cache_seqlens=sequence_len_offset,
        rotary_interleaved=rotary_interleaved,
    )
    return out
def _flash_attention_3_forward_wrapper(
    self,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    max_seqlen_q,
    max_seqlen_k,
    cu_seqlens_q,
    seqlens_k,
    block_table,
    softmax_scale,
):
    """Call FA3's ``_flash_attn_forward`` across API versions and return its output.

    A superset of known keyword arguments is built first. Only the names present
    in the installed kernel's signature are forwarded, so different FA3 releases
    can be called the same way. Causal attention is always requested, and
    ``num_splits`` is 1 in batch-invariant mode, otherwise 0.
    """
    split_count = 1 if self.batch_invariant_mode else 0
    all_kwargs = dict(
        q=q,
        k=k,
        v=v,
        k_new=None,
        v_new=None,
        qv=None,
        out=None,
        out_=None,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=None,
        cu_seqlens_k_new=None,
        seqused_q=None,
        seqused_k=seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        page_table=block_table,
        kv_batch_idx=None,
        leftpad_k=None,
        rotary_cos=None,
        rotary_sin=None,
        seqlens_rotary=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        softmax_scale=softmax_scale,
        causal=True,
        attention_chunk=0,
        softcap=0.0,
        window_size=(-1, -1),
        window_size_left=-1,
        window_size_right=-1,
        rotary_interleaved=True,
        scheduler_metadata=None,
        num_splits=split_count,
        pack_gqa=None,
        sm_margin=0,
    )

    # The entry point is either a plain Python function or a torch custom op;
    # for a custom op, the wrapped Python function is ``_init_fn``.
    if inspect.isfunction(_flash_attn_forward):
        introspected_fn = _flash_attn_forward
    else:
        assert isinstance(_flash_attn_forward, torch._library.custom_ops.CustomOpDef)
        introspected_fn = _flash_attn_forward._init_fn
    accepted_names = inspect.signature(introspected_fn).parameters

    call_kwargs = {
        name: arg for name, arg in all_kwargs.items() if name in accepted_names
    }
    output_total, *_rest = _flash_attn_forward(**call_kwargs)
    return output_total
def flash_decode_and_prefill(
    self,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    max_seqlen_q,
    max_seqlen_k,
    cu_seqlens_q,
    cu_seqlens_k,
    seqlens_k,
    block_table,
    is_decode_only,
) -> Tensor:
    """Flash attention kernel for mixed decode and prefill samples.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor (paged KV cache).
        v (Tensor): Value tensor (paged KV cache).
        max_seqlen_q (int): Maximum per-sample query length, passed to the kernel.
        max_seqlen_k (int): Maximum per-sample key length, passed to the kernel.
        cu_seqlens_q (Tensor): Cumulative query sequence lengths.
        cu_seqlens_k (Tensor): Cumulative key sequence lengths (FA2 prefill only).
        seqlens_k (Tensor): Per-sample key sequence lengths.
        block_table (Tensor): KV cache block ids for all samples.
        is_decode_only (bool): True if batch is decode only.

    Return:
        (Tensor) Attention output.
    """
    assert not self.training
    assert block_table is not None

    # Resolve the softmax scale once so prefill and decode agree. Precedence:
    # a module-level ``softmax_scale`` attribute if present, then
    # ``config.softmax_scale`` (the value core_attention is built with), then the
    # conventional 1/sqrt(head_dim). Previously the prefill path skipped
    # ``config.softmax_scale``. For the decode kernels this is numerically
    # identical to passing None, because they default to q.shape[-1] ** -0.5.
    softmax_scale = getattr(self, "softmax_scale", self.config.softmax_scale)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5

    # Flash attn kernel.
    if not is_decode_only:
        q = q.squeeze(1)
        if HAVE_FA3:
            # TODO(ksanthanam): Replace with call to flash_attn_varlen_func once
            # it accepts block_table
            output_total = self._flash_attention_3_forward_wrapper(
                q,
                k,
                v,
                max_seqlen_q,
                max_seqlen_k,
                cu_seqlens_q,
                seqlens_k,
                block_table,
                softmax_scale,
            )
        else:
            assert (
                not self.batch_invariant_mode
            ), "Batch invariant mode is not supported for flash attention 2"
            assert flash_attn_varlen_func is not None, (
                "Prefill without FlashAttention-3 requires flash_attn_varlen_func, "
                "available in the flash-attn package."
            )
            output_total = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                softmax_scale=softmax_scale,
                causal=True,
                block_table=block_table,
            )
        # Restore the singleton dim removed by the squeeze above.
        output_total = output_total.unsqueeze(1)
    else:  # decode only
        # If using MLA we use the FlashMLA kernel when possible.
        block_size_tokens = k.size(1)
        if (
            isinstance(self.config, MLATransformerConfig)
            # Only use FlashMLA when the block size matches
            and block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
        ):
            num_heads_k = 1  # Only a single head for MLA Flash
            seq_len_q = 1  # Sequence length is 1 for decode
            num_heads_q = self.num_attention_heads_per_partition
            num_heads_per_head_k = seq_len_q * num_heads_q // num_heads_k
            cache_seqlens = seqlens_k
            tile_scheduler_metadata, num_splits = get_mla_metadata(
                cache_seqlens,  # per-sample KV-cache lengths
                num_heads_per_head_k,  # query rows per KV head (seq_len_q * heads_q / heads_k)
                num_heads_k,  # number of KV heads
            )
            head_dim_v = self.config.kv_lora_rank
            kv_cache = k.unsqueeze(-2)
            output_total, _ = flash_mla_with_kvcache(
                q,
                kv_cache,
                block_table,
                cache_seqlens,
                head_dim_v,
                tile_scheduler_metadata,
                num_splits,
                softmax_scale=self.softmax_scale,
                causal=True,
            )
        else:
            flash_attn_args = {
                "q": q,
                "k_cache": k,
                "v_cache": v,
                "cache_seqlens": seqlens_k,
                "causal": True,
                "page_table" if HAVE_FA3 else "block_table": block_table,
                "num_splits": 0 if not self.batch_invariant_mode else 1,
                "softmax_scale": softmax_scale,
            }
            if HAVE_FA3:
                output_total = flash_attn3_with_kvcache(**flash_attn_args)
            else:
                assert (
                    not self.batch_invariant_mode
                ), "Batch invariant mode is not supported for flash attention 2"
                assert flash_attn_with_kvcache is not None, (
                    "Decode without FlashAttention-3 requires flash_attn_with_kvcache, "
                    "available in the flash-attn package."
                )
                output_total = flash_attn_with_kvcache(**flash_attn_args)
    return output_total
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Optional[Tensor] = None,
inference_context: Optional[BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
rotary_pos_cos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
) -> tuple[Tensor, Tensor]:
"""
Perform a forward pass through the attention module.
Args:
hidden_states (Tensor): Hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Optional[Tensor]): Key/value states (for cross attention).
inference_context (Optional[BaseInferenceContext]): Inference context that manages
KV cache.
rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary
embedding tensor(s).
rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.
rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.
rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine.
Currently used exclusively for inference with dynamic batching and flashinfer RoPE.
attention_bias (Optional[Tensor]): Attention bias.
packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format.
sequence_len_offset (Optional[int]): Sequence length offset used for
inference CUDA graphs.
Return:
(Tuple[Tensor, Tensor]) Attention output and bias.
"""
# Check if we need to skip RoPE
# no_rope is 0-indexed array and self.layer_number is 1-indexed
no_rope = (
self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False
)
if no_rope:
rotary_pos_emb = None
inference_context = deprecate_inference_params(inference_context, inference_params)
if inference_context and inference_context.is_dynamic_batching():
assert HAVE_FA3 or is_fa_min_version(
"2.7.3"
), "flash attn verion v2.7.3 and above is required for dynamic batching."
# hidden_states: [sq, b, h]
is_inference_mode = inference_context is not None and not self.training
# is_using_flash_decode - True is we are using the static inference engine with flash decode
is_using_flash_decode = is_inference_mode and self.config.flash_decode
# is_using_flashinfer_rope - True if we are using the dynamic inference engine
# with flashinfer fused rope
is_using_flashinfer_rope = is_inference_mode and (
not inference_context.is_static_batching()
and inference_context.use_flashinfer_fused_rope
)
if is_using_flash_decode or is_using_flashinfer_rope:
# flash decode and flash-infer fused rope use rotary_pos_cos and rotary_pos_sin
rotary_pos_emb = None
else:
assert rotary_pos_cos is None and rotary_pos_sin is None
# For self attention we just duplicate the rotary_pos_emb if it isn't already
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
nvtx_range_push(suffix="qkv")
split_qkv = (self.attention_type == "cross") or not all(
[
not self.config.test_mode,
self.config.fused_single_qkv_rope,
inference_context is None,
packed_seq_params is None,
(
rotary_pos_emb is not None
and rotary_pos_emb[0] is not None
and rotary_pos_emb[1] is not None
),
not self.config.flash_decode,
HAVE_FUSED_QKV_ROPE,
self.q_layernorm is None or isinstance(self.q_layernorm, IdentityOp),
self.k_layernorm is None or isinstance(self.k_layernorm, IdentityOp),
]
)
# Check if fused_single_qkv_rope is requested but either unavailable or not
# supported for the current use case.
if self.attention_type != "cross":
assert not (
self.config.fused_single_qkv_rope and split_qkv
), "fused_single_qkv_rope requested but not available/supported for the config."