forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformer_layer.py
More file actions
1515 lines (1324 loc) · 69 KB
/
transformer_layer.py
File metadata and controls
1515 lines (1324 loc) · 69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from __future__ import annotations
import functools
import logging
import warnings
from abc import ABC
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import torch
import torch.distributed
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.transformer.enums import CudaGraphScope, LayerType
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.module import GraphableMegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.torch_norm import LayerNormBuilder
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.typed_torch import apply_module, copy_signature
from megatron.core.utils import (
deprecate_inference_params,
get_pg_rank,
is_te_min_version,
log_single_rank,
make_viewless_tensor,
nvtx_range_pop,
nvtx_range_push,
)
if TYPE_CHECKING:
from megatron.core.inference.contexts import BaseInferenceContext
logger = logging.getLogger(__name__)
def get_transformer_layer_offset(
    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
) -> int:
    """Get the index offset of the current pipeline stage, given the level of pipelining.

    The offset is the 0-based global index of the first transformer layer held by the
    model chunk identified by ``pp_rank`` (and ``vp_stage`` under virtual pipelining).

    Args:
        config: Transformer config. Reads the pipeline / virtual pipeline sizes, the
            optional ``pipeline_model_parallel_layout``, the optional uneven
            ``num_layers_in_{first,last}_pipeline_stage`` counts and the
            ``account_for_{embedding,loss}_in_pipeline_split`` flags.
        vp_stage: Virtual pipeline stage of the chunk. Must be given when
            ``config.virtual_pipeline_model_parallel_size`` is set (outside the layout path).
        pp_rank: Pipeline rank. Defaults to this process's pipeline-model-parallel rank.

    Returns:
        The layer offset; 0 when pipeline parallelism is disabled.
    """
    if pp_rank is None:
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
    is_first_pp_stage = pp_rank == 0

    if config.pipeline_model_parallel_size <= 1:
        # No pipeline parallelism: the single stage starts at layer 0.
        return 0

    if config.pipeline_model_parallel_layout:
        # NOTE(review): this path does not use ``pp_rank``; confirm the layout resolves the
        # rank itself before relying on an explicit, non-local ``pp_rank`` here.
        return config.pipeline_model_parallel_layout.get_layer_offset(
            layer_type=LayerType.decoder, vp_stage=vp_stage
        )

    vp_size = config.virtual_pipeline_model_parallel_size
    if vp_size is not None:
        assert (
            vp_stage is not None
        ), "vp_stage must be provided if virtual pipeline model parallel size is set"

    first_stage_layers = config.num_layers_in_first_pipeline_stage
    last_stage_layers = config.num_layers_in_last_pipeline_stage

    if first_stage_layers is not None or last_stage_layers is not None:
        # Uneven split: the first and/or last stage hold an explicit number of layers and
        # the remaining ("middle") stages share what is left evenly. Unset counts are 0.
        num_first = 0 if first_stage_layers is None else first_stage_layers
        num_last = 0 if last_stage_layers is None else last_stage_layers
        middle_pipeline_stages = config.pipeline_model_parallel_size - sum(
            count is not None for count in (first_stage_layers, last_stage_layers)
        )
        middle_num_layers = config.num_layers - num_first - num_last
        # Rank among the middle stages (shifted by one when the first stage is pinned).
        middle_pipeline_rank = pp_rank if first_stage_layers is None else pp_rank - 1

        if vp_size is None:
            if pp_rank == 0:
                return 0
            layers_per_middle_rank = (
                middle_num_layers // middle_pipeline_stages if middle_pipeline_stages > 0 else 0
            )
            return middle_pipeline_rank * layers_per_middle_rank + num_first

        # Interleaved pipeline: each stage's layers are split evenly over vp_size chunks, and
        # one virtual stage spans one chunk from every pipeline stage.
        first_per_chunk = num_first // vp_size
        last_per_chunk = num_last // vp_size
        middle_per_chunk = middle_num_layers // vp_size
        layers_per_virtual_stage = first_per_chunk + middle_per_chunk + last_per_chunk
        if pp_rank == 0:
            return vp_stage * layers_per_virtual_stage
        # Guard the division: e.g. PP=2 with both ends pinned leaves no middle stages, and
        # the last stage must still get an offset instead of a ZeroDivisionError.
        middle_per_rank_per_chunk = (
            middle_per_chunk // middle_pipeline_stages if middle_pipeline_stages > 0 else 0
        )
        return (
            vp_stage * layers_per_virtual_stage
            + first_per_chunk
            + middle_pipeline_rank * middle_per_rank_per_chunk
        )

    # Even split across all stages.
    # import here to avoid circular import
    from megatron.core.pipeline_parallel.utils import is_vp_first_stage

    num_layers = config.num_layers
    # Increase the number of layers by one if we include the embedding (loss)
    # layer into pipeline parallelism partition and placement
    if config.account_for_embedding_in_pipeline_split:
        num_layers += 1
    if config.account_for_loss_in_pipeline_split:
        num_layers += 1
    num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size

    if vp_size is not None:
        num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
        total_virtual_chunks = num_layers // vp_size
        offset = vp_stage * total_virtual_chunks + pp_rank * num_layers_per_virtual_rank
    else:
        offset = pp_rank * num_layers_per_pipeline_rank

    # The embedding takes one layer slot on the very first chunk only, so every later
    # chunk's offset shifts back by one.
    if config.account_for_embedding_in_pipeline_split and not (
        is_vp_first_stage(vp_stage, vp_size) and is_first_pp_stage
    ):
        offset -= 1
    return offset
@dataclass
class TransformerLayerSubmodules:
    """
    Configuration class for specifying the submodules of a transformer layer.

    ``TransformerLayer`` builds one submodule from each field below. Every field
    defaults to ``IdentityOp`` (or ``IdentityFuncOp`` for the bias-dropout-add slots),
    so a spec only needs to set the components a particular layer actually uses.

    Args:
        input_layernorm: Builder for the normalization applied to the layer input.
            ``TransformerLayer`` calls it with ``config``, ``hidden_size`` and ``eps``.
        self_attention (Union[ModuleSpec, type]): Specification for the self-attention
            mechanism; built with ``config`` and ``layer_number``.
        self_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add
            operation after self-attention.
        pre_cross_attn_layernorm: Builder for the normalization applied before
            cross-attention (same call convention as ``input_layernorm``).
        cross_attention (Union[ModuleSpec, type]): Specification for the cross-attention
            mechanism.
        cross_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add
            operation after cross-attention.
        pre_mlp_layernorm: Builder for the normalization applied before the MLP (same
            call convention as ``input_layernorm``).
        mlp (Union[ModuleSpec, type]): Specification for the MLP of the layer; a
            ``ModuleSpec`` may also name an MoE module.
        mlp_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add
            operation after the MLP.
        sharded_state_dict_keys_map (Dict[str, str]): Mapping for sharded tensor keys to
            be applied in the `sharded_state_dict` method.
    """

    # Normalization before self-attention.
    input_layernorm: LayerNormBuilder = IdentityOp
    self_attention: Union[ModuleSpec, type] = IdentityOp
    self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp

    # Normalization before cross-attention.
    pre_cross_attn_layernorm: LayerNormBuilder = IdentityOp
    cross_attention: Union[ModuleSpec, type] = IdentityOp
    cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp

    # Normalization before the MLP.
    pre_mlp_layernorm: LayerNormBuilder = IdentityOp
    mlp: Union[ModuleSpec, type] = IdentityOp
    mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp

    # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method.
    # default_factory gives each instance its own dict rather than a shared one.
    sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
class BaseTransformerLayer(ABC):
    """Marker base class shared by ``TransformerLayer``-style implementations.

    It carries no behavior. Its only job is to be detectable: ``TransformerBlock``
    checks whether a layer (or module) given in a spec subclasses this type, and if so
    fans that spec out to every layer of the block. See ``_get_block_submodules`` in
    ``transformer_block.py`` for where that check happens.
    """

    def __init__(self):
        """No-op initializer; nothing needs to be set up on the marker class."""
class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
    self,
    config: TransformerConfig,
    submodules: TransformerLayerSubmodules,
    layer_number: int = 1,
    hidden_dropout: Optional[float] = None,
    pg_collection: Optional[ProcessGroupCollection] = None,
    vp_stage: Optional[int] = None,
    is_mtp_layer: bool = False,
    add_layer_offset: bool = True,
    pp_layer_offset: Optional[int] = None,
):
    """Build every submodule of the layer and configure selective recompute / offloading.

    Args:
        config: Transformer config shared by all submodules.
        submodules: Specs/classes for each submodule of the layer.
        layer_number: 1-based layer number. Unless ``is_mtp_layer`` is True or
            ``add_layer_offset`` is False, the pipeline offset from
            ``get_transformer_layer_offset`` is added to it.
        hidden_dropout: Dropout probability used by the bias-dropout-add ops; defaults to
            ``config.hidden_dropout``.
        pg_collection: Process groups; defaults to
            ``ProcessGroupCollection.use_mpu_process_groups()``.
        vp_stage: Virtual pipeline stage of this layer's model chunk.
        is_mtp_layer: Whether this is an inner multi-token-prediction layer. Such layers
            keep their own numbering and the flag is forwarded to ``MoELayer``.
        add_layer_offset: Whether to add the pipeline layer offset to ``layer_number``.
        pp_layer_offset: When not None, forwarded to the self- and cross-attention modules
            as the ``pp_layer_offset`` kwarg.
    """
    self.submodules_config = submodules
    super().__init__(config=config, vp_stage=vp_stage)

    if pg_collection is None:
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
    self.pg_collection = pg_collection
    self.tp_group = pg_collection.tp

    # MTP inner layers use their own layer numbering (starting from 1 within each MTP depth),
    # so they should NOT add the decoder layer offset. The router.py handles MTP layer
    # numbering separately by adding config.num_layers to distinguish MTP layers from decoder
    # layers in the aux loss tracker.
    #
    # When add_layer_offset is False, the caller has already included the correct offset
    # in layer_number (e.g. when using --hybrid-layer-pattern with fVPP).
    if is_mtp_layer or not add_layer_offset:
        self.layer_number = layer_number
    else:
        self.layer_number = layer_number + get_transformer_layer_offset(
            self.config, vp_stage, get_pg_rank(pg_collection.pp)
        )

    self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
    self.is_mtp_layer = is_mtp_layer

    # [Module 1: Input Layernorm] Optional Layernorm on the input data
    # TODO: add pytorch only layernorm
    self.input_layernorm = submodules.input_layernorm(
        config=self.config,
        hidden_size=self.config.hidden_size,
        eps=self.config.layernorm_epsilon,
    )

    attention_optional_kwargs = {}
    if config.context_parallel_size > 1 and config.cp_comm_type is not None:
        if isinstance(config.cp_comm_type, list):
            # layer_number is 1-indexed, so we need to subtract 1 to get the correct index
            attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type[
                self.layer_number - 1
            ]
        else:
            attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type

    attention_optional_kwargs["pg_collection"] = pg_collection
    if pp_layer_offset is not None:
        attention_optional_kwargs["pp_layer_offset"] = pp_layer_offset

    # [Module 2: SelfAttention]
    self.self_attention = build_module(
        submodules.self_attention,
        config=self.config,
        layer_number=self.layer_number,
        **attention_optional_kwargs,
    )

    # [Module 3: BiasDropoutFusion]
    self.self_attn_bda = build_module(submodules.self_attn_bda)

    # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
    self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm(
        config=self.config,
        hidden_size=self.config.hidden_size,
        eps=self.config.layernorm_epsilon,
    )

    # [Module 5: CrossAttention]
    self.cross_attention = build_module(
        submodules.cross_attention,
        config=self.config,
        layer_number=self.layer_number,
        **attention_optional_kwargs,
    )

    # [Module 6: BiasDropoutFusion]
    self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)

    # [Module 7: Pre MLP] Optional Layernorm before MLP
    self.pre_mlp_layernorm = submodules.pre_mlp_layernorm(
        config=self.config,
        hidden_size=self.config.hidden_size,
        eps=self.config.layernorm_epsilon,
    )

    # [Module 8: MLP block]
    additional_mlp_kwargs = {}
    # import here to avoid circular import
    from megatron.core.extensions.transformer_engine import TEFusedMLP
    from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP
    from megatron.core.transformer.moe.moe_layer import MoELayer

    # MLP expects tp_group but MoELayer expects pg_collection to be passed in.
    # We can change MLP to accept pg_collection but it makes the logic implicit
    # The conditional below is to make the logic explicit
    # If submodules.mlp is not a ModuleSpec, we don't have to pass additional kwargs.
    if isinstance(submodules.mlp, ModuleSpec):
        if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP):
            additional_mlp_kwargs["pg_collection"] = pg_collection
            # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers.
            if submodules.mlp.module == MoELayer:
                additional_mlp_kwargs["is_mtp_layer"] = self.is_mtp_layer
        elif submodules.mlp.module == MLP:
            assert hasattr(
                pg_collection, 'tp'
            ), 'TP process group is required for MLP in TransformerLayer'
            additional_mlp_kwargs["tp_group"] = pg_collection.tp
        elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
            assert hasattr(
                pg_collection, 'tp'
            ), 'TP process group is required for TEFusedMLP in TransformerLayer'
            additional_mlp_kwargs["tp_group"] = pg_collection.tp
        else:
            # Report the spec's module: inside this branch type(submodules.mlp) is always
            # ModuleSpec, which would not identify the offending MLP class.
            log_single_rank(
                logger,
                logging.WARNING,
                f"Unknown MLP type: {submodules.mlp.module}. Using default kwargs.",
            )
    self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
    if hasattr(self.mlp, 'set_layer_number'):
        self.mlp.set_layer_number(self.layer_number)

    # [Module 9: BiasDropoutFusion]
    self.mlp_bda = build_module(submodules.mlp_bda)

    self.is_moe_layer = isinstance(self.mlp, MoELayer)

    # Selective recompute flags; only set when the corresponding module is not an IdentityOp.
    self.recompute_input_layernorm = False
    self.recompute_pre_mlp_layernorm = False
    self.recompute_mlp = False
    if self.config.recompute_granularity == 'selective':
        assert self.config.recompute_modules is not None
        if "layernorm" in self.config.recompute_modules:
            if not isinstance(self.input_layernorm, IdentityOp):
                self.recompute_input_layernorm = True
                if self.config.fp8 or self.config.fp4:
                    self.self_attention.set_for_recompute_input_layernorm()

            def can_recompute_pre_mlp_layernorm_for_cudagraph():
                if (
                    not self.is_moe_layer
                    or CudaGraphScope.moe_router not in self.config.cuda_graph_scope
                    or self.config.cuda_graph_impl == "local"
                ):
                    # Not a MoE layer, or not capturing the router part.
                    return True
                if (
                    self.config.moe_shared_expert_intermediate_size is not None
                    and self.config.moe_shared_expert_overlap
                ):
                    # If shared expert overlap is used, we cannot make the pre-mlp layernorm
                    # recomputation, because the shared expert takes the layernorm output as
                    # input, and it is outside of the CUDA graph scope.
                    log_single_rank(
                        logger,
                        logging.WARNING,
                        "pre_mlp_layernorm recompute is not supported with moe router "
                        "cudagraph + shared expert overlap. Disabling pre_mlp_layernorm "
                        "recompute.",
                    )
                    return False
                if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and (
                    self.config.moe_token_dispatcher_type == "alltoall"
                    or self.config.moe_latent_size
                ):
                    # Only when capturing the preprocess part and using alltoall token
                    # dispatcher or latent MoE can we make the pre-mlp layernorm recomputation.
                    # Because in other cases the layernorm output returns directly as one of the
                    # outputs of the cudagraph, which will be allocated a static buffer, thus
                    # not able to be released.
                    return True
                log_single_rank(
                    logger,
                    logging.WARNING,
                    "pre_mlp_layernorm recompute is only supported with moe router + "
                    "preprocess cudagraph with alltoall token dispatcher or latent MoE. "
                    "Disabling pre_mlp_layernorm recompute.",
                )
                return False

            if (
                not isinstance(self.pre_mlp_layernorm, IdentityOp)
                and can_recompute_pre_mlp_layernorm_for_cudagraph()
            ):
                self.recompute_pre_mlp_layernorm = True
                if self.config.fp8 or self.config.fp4:
                    if isinstance(self.mlp, MoELayer):
                        self.mlp.set_for_recompute_pre_mlp_layernorm()
                    else:
                        from megatron.core.extensions.transformer_engine import (
                            set_save_original_input,
                        )

                        set_save_original_input(self.mlp.linear_fc1)
        if "mlp" in self.config.recompute_modules:
            # MoE layers are excluded from whole-MLP recompute here.
            if not self.is_moe_layer:
                self.recompute_mlp = True

    # Fine-grained activation offloading of the norm inputs; skipped for IdentityOp norms.
    self.offload_attn_norm = (
        self.config.fine_grained_activation_offloading
        and "attn_norm" in self.config.offload_modules
        and not isinstance(self.input_layernorm, IdentityOp)
    )
    self.offload_mlp_norm = (
        self.config.fine_grained_activation_offloading
        and "mlp_norm" in self.config.offload_modules
        and not isinstance(self.pre_mlp_layernorm, IdentityOp)
    )

    # @jcasper how should we handle nvfuser?
    # Set bias+dropout+add fusion grad_enable execution handler.
    # TORCH_MAJOR = int(torch.__version__.split('.')[0])
    # TORCH_MINOR = int(torch.__version__.split('.')[1])
    # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
    # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
    self.bias_dropout_add_exec_handler = torch.enable_grad
def create_mcore_cudagraph_manager(self, config):
    """Attach a ``CudaGraphManager`` when the configured CUDA-graph scope covers this layer.

    An empty ``cuda_graph_scope`` means the whole layer is graphed. Otherwise a manager
    is created when the ``attn`` scope is requested and the spec has a non-identity
    self-attention, or when the ``mlp`` scope is requested and the spec has a
    non-identity MLP. In the MLP case the layer must not be MoE; MoE layers are expected
    to be handled by MoETransformerLayer. If no case applies, no manager is attached.
    """
    from megatron.core.transformer.cuda_graphs import CudaGraphManager

    graph_scope = self.config.cuda_graph_scope
    specs = self.submodules_config

    if graph_scope:
        graph_attention = (
            CudaGraphScope.attn in graph_scope and specs.self_attention != IdentityOp
        )
        if not graph_attention:
            graph_mlp = CudaGraphScope.mlp in graph_scope and specs.mlp != IdentityOp
            if not graph_mlp:
                return
            assert not self.is_moe_layer

    self.cudagraph_manager = CudaGraphManager(config)
@staticmethod
def _get_layer_offset(config: TransformerConfig):
    """
    Get the layer offset for the current pipeline stage.

    Deprecated: please use `get_transformer_layer_offset` instead. This wrapper emits a
    warning and then delegates to it with the current process's pipeline rank.
    """
    # The two literals previously ran together ("deprecated.Please"); stacklevel=2 makes
    # the warning point at the caller instead of this wrapper.
    warnings.warn(
        "TransformerLayer._get_layer_offset is deprecated. "
        "Please use get_transformer_layer_offset instead.",
        stacklevel=2,
    )
    return get_transformer_layer_offset(config)
def _forward_attention(
    self,
    hidden_states: Tensor,
    attention_mask: Optional[Tensor] = None,
    context: Optional[Tensor] = None,
    context_mask: Optional[Tensor] = None,
    rotary_pos_emb: Optional[Tensor] = None,
    rotary_pos_cos: Optional[Tensor] = None,
    rotary_pos_sin: Optional[Tensor] = None,
    rotary_pos_cos_sin: Optional[Tensor] = None,
    attention_bias: Optional[Tensor] = None,
    inference_context: Optional[BaseInferenceContext] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
    sequence_len_offset: Optional[Tensor] = None,
    padding_mask: Optional[Tensor] = None,
    *,
    inference_params: Optional[Any] = None,
):
    """
    Perform a forward pass through the attention layer and the layernorms before and after
    the attention operations.

    Order of operations: input layernorm -> self-attention -> self-attention
    bias-dropout-add -> pre-cross-attention layernorm -> cross-attention ->
    cross-attention bias-dropout-add.

    Args:
        hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
            b is batch size, and h is hidden size.
        attention_mask (Tensor): Mask tensor for self-attention.
        context (Tensor, optional): Context tensor for cross-attention.
        context_mask (Tensor, optional): Mask tensor for cross-attention.
        rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
        rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.
        rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.
        rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine.
            Currently used exclusively for inference with dynamic batching and flashinfer RoPE.
        attention_bias (Tensor, optional): Bias tensor for Q * K.T.
        inference_context (object, optional): Parameters for inference-time optimizations.
        packed_seq_params (object, optional): Parameters for packed sequence processing.
        sequence_len_offset (Tensor, optional): Offset along sequence dimension
            during inference.
        padding_mask (Tensor, optional): Not used by this method. It is accepted so that
            `forward` (which copies this signature) can take it and pass it to the MLP.
        inference_params (optional, keyword-only): Deprecated alias for
            `inference_context`, resolved via `deprecate_inference_params`.

    Returns:
        Tuple[Tensor, Tensor]: A tuple containing:
            hidden_states (Tensor): Output of the cross-attention bias-dropout-add, i.e.
                the hidden states before the MLP layernorm.
            context (Tensor): The input `context`, replaced by
                `attention_output_with_bias["context"]` when cross-attention returns a
                dict containing that key; otherwise unchanged (possibly None).
    """
    from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
        FineGrainedActivationOffloadingInterface as off_interface,
    )

    inference_context = deprecate_inference_params(inference_context, inference_params)

    # Optional Input Layer norm. The off_interface context rebinds hidden_states to the
    # tensor it yields. With recompute enabled, the norm runs under a checkpoint whose
    # output is discarded after self-attention (see below).
    if self.recompute_input_layernorm:
        self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
        with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states:
            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
                apply_module(self.input_layernorm), hidden_states
            )
    else:
        with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states:
            input_layernorm_output = apply_module(self.input_layernorm)(hidden_states)

    # A norm may return (output, residual); otherwise the residual is the norm input.
    if isinstance(input_layernorm_output, tuple):
        if len(input_layernorm_output) != 2:
            raise ValueError(
                f"When the output of input_layernorm is a tuple, it is "
                f"expected to have 2 elements (output, residual), but "
                f"got {len(input_layernorm_output)}"
            )
        input_layernorm_output, residual = input_layernorm_output
    else:
        residual = hidden_states

    # Applied to the residual from either branch above.
    if self.config.fp32_residual_connection:
        residual = residual.float()

    using_fused_tp_inference_kernel = (not self.training) and (
        self.config.inference_fuse_tp_communication
    )

    if using_fused_tp_inference_kernel:
        # Set the residual for fused reduce-scatter + add + layer-norm + all-gather
        # operation in attention's out_proj (linear_proj)
        self._set_proj_residual(residual)

    # Self attention.
    nvtx_range_push(suffix="self_attention")
    attention_output_with_bias = self.self_attention(
        input_layernorm_output,
        attention_mask=attention_mask,
        inference_context=inference_context,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        rotary_pos_cos_sin=rotary_pos_cos_sin,
        attention_bias=attention_bias,
        packed_seq_params=packed_seq_params,
        sequence_len_offset=sequence_len_offset,
    )
    nvtx_range_pop(suffix="self_attention")

    if self.recompute_input_layernorm:
        # discard the output of the input layernorm and register the recompute
        # as a gradient hook of attention_output_with_bias[0]
        self.input_layernorm_checkpoint.discard_output_and_register_recompute(
            attention_output_with_bias[0]
        )

    # TODO: could we move `bias_dropout_add_exec_handler` itself
    # inside the module provided in the `bias_dropout_add_spec` module?
    nvtx_range_push(suffix="self_attn_bda")
    if using_fused_tp_inference_kernel:
        # In inference optimized transformer layer, there is no bias and dropout
        # The remaining residual add is already handled inside the
        # self attention module.
        hidden_states = attention_output_with_bias[0]
    else:
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
                attention_output_with_bias, residual, self.hidden_dropout
            )
    nvtx_range_pop(suffix="self_attn_bda")

    # Delay the offload of the attention norm until after the self_attn_bda has been computed
    # because the residual is needed in the self_attn_bda.
    if self.offload_attn_norm:
        hidden_states = off_interface.group_commit(
            hidden_states, name="attn_norm", forced_released_tensors=[residual]
        )

    # Optional Layer norm after self-attention
    pre_cross_attn_layernorm_output = apply_module(self.pre_cross_attn_layernorm)(hidden_states)
    if isinstance(pre_cross_attn_layernorm_output, tuple):
        if len(pre_cross_attn_layernorm_output) != 2:
            raise ValueError(
                f"When the output of pre_cross_attn_layernorm_output "
                f"is a tuple, it is expected to have 2 elements "
                f"(output, residual), but "
                f"got {len(pre_cross_attn_layernorm_output)}"
            )
        pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output
    else:
        residual = hidden_states

    # Applied to the residual from either branch above.
    if self.config.fp32_residual_connection:
        residual = residual.float()

    # Cross attention.
    attention_output_with_bias = self.cross_attention(
        pre_cross_attn_layernorm_output,
        attention_mask=context_mask,
        key_value_states=context,
        inference_context=inference_context,
    )

    # A cross-attention module may return a dict carrying an updated context.
    if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
        context = attention_output_with_bias["context"]

    # TODO: could we move `bias_dropout_add_exec_handler` itself
    # inside the module provided in the `bias_dropout_add_spec` module?
    with self.bias_dropout_add_exec_handler():
        hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
            attention_output_with_bias, residual, self.hidden_dropout
        )

    return hidden_states, context
@copy_signature(_forward_attention)
def forward(self, *args, **kwargs):
    """
    Perform a forward pass through the transformer layer.

    This method calls the core computation of a transformer layer, including
    self-attention, cross-attention (if applicable), and feed-forward operations.

    Accepts exactly the arguments of `_forward_attention`. `inference_context` (or its
    deprecated keyword alias `inference_params`) and `padding_mask` are also forwarded to
    the MLP, whether they were passed by keyword or positionally.

    Returns:
        Tuple of (result of `_forward_mlp`, context returned by `_forward_attention`).
    """
    hidden_states, context = self._forward_attention(*args, **kwargs)

    if len(args) > 1:
        # Slow path: arguments beyond hidden_states were passed positionally. Map them to
        # parameter names exactly as the call above did, so they are not silently dropped.
        import inspect

        call_args = inspect.signature(self._forward_attention).bind(*args, **kwargs).arguments
    else:
        call_args = kwargs

    # Resolve the deprecated alias here too; otherwise callers using `inference_params`
    # would reach the MLP with no inference context (disabling chunked prefill).
    inference_context = deprecate_inference_params(
        call_args.get("inference_context", None), call_args.get("inference_params", None)
    )
    output = self._forward_mlp(
        hidden_states,
        inference_context,
        padding_mask=call_args.get("padding_mask", None),
    )
    return output, context
def _forward_pre_mlp_layernorm(self, hidden_states: Tensor):
    """Apply the pre-MLP layernorm, optionally checkpointed and/or offloaded.

    When `self.recompute_pre_mlp_layernorm` is set, a fresh `CheckpointWithoutOutput` is
    stored on `self.pre_mlp_norm_checkpoint` before the norm runs through it. The norm
    input always passes through the fine-grained offloading context (tag `"mlp_norm"`),
    which rebinds `hidden_states` to the tensor it yields.

    Returns:
        Whatever the layernorm returns (a tensor, or an (output, residual) tuple).
    """
    from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
        FineGrainedActivationOffloadingInterface as off_interface,
    )

    recompute = self.recompute_pre_mlp_layernorm
    if recompute:
        self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()

    with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states:
        layernorm_fn = apply_module(self.pre_mlp_layernorm)
        if recompute:
            normed = self.pre_mlp_norm_checkpoint.checkpoint(layernorm_fn, hidden_states)
        else:
            normed = layernorm_fn(hidden_states)
    return normed
def _forward_mlp(
    self,
    hidden_states: Tensor,
    inference_context: BaseInferenceContext | None = None,
    padding_mask: Tensor | None = None,
) -> Tensor | list[Tensor | None]:
    """
    Perform a forward pass through the feed-forward layer.

    Args:
        hidden_states (Tensor): Transformed hidden states before the MLP layernorm.
            Shape [seq_length, batch_size, hidden_size].
        inference_context: Inference context for optimizations.
        padding_mask (Tensor, optional): Padding mask for MoE routing.
            Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
            Only used for MoE layers to exclude padding tokens from aux loss computations.
            The MoELayer will internally transform this to [seq_length, bsz] format.

    Returns:
        Tensor | list[Tensor | None]: Normally the transformed hidden states of shape
        [s, b, h], i.e. the result of ``_forward_post_mlp``. When a MoE layer is being
        captured by a Transformer Engine CUDA graph with the ``moe_router`` scope during
        training, ``_forward_post_mlp`` is NOT called and the list
        ``[mlp_output, mlp_bias, residual]`` is returned instead.

    Raises:
        ValueError: If ``pre_mlp_layernorm`` returns a tuple that does not have exactly
            two elements ``(output, residual)``.
    """
    # Optional Layer norm post the cross-attention.
    pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states)
    if isinstance(pre_mlp_layernorm_output, tuple):
        # The norm module may hand back its own residual alongside the normed output.
        if len(pre_mlp_layernorm_output) != 2:
            raise ValueError(
                f"When the output of pre_mlp_layernorm is a tuple, it is "
                f"expected to have 2 elements (output, residual), but "
                f"got {len(pre_mlp_layernorm_output)}"
            )
        pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output
    else:
        # Residual connection.
        residual = hidden_states

    if self.config.fp32_residual_connection:
        residual = residual.float()

    nvtx_range_push(suffix="mlp")
    # Potentially chunk the MLP computation during prefill to minimize the peak activation size
    should_chunk_mlp_for_prefill = (
        self.config.mlp_chunks_for_prefill > 1
        and inference_context is not None
        and not inference_context.is_decode_only()
        and not isinstance(self.mlp, IdentityOp)
        and not self.config.transformer_impl == "inference_optimized"
    )

    using_fused_tp_inference_kernel = (not self.training) and (
        self.config.inference_fuse_tp_communication
    )

    if self.recompute_mlp:
        if self.config.fp8 or self.config.fp4:
            # import here to avoid circular import
            from megatron.core.extensions.transformer_engine import te_checkpoint

            mlp_output_with_bias = te_checkpoint(
                self.mlp,
                False,
                tensor_parallel.random.get_cuda_rng_tracker,
                self.pg_collection.tp,
                pre_mlp_layernorm_output,
                padding_mask=padding_mask,
            )
        else:
            mlp_output_with_bias = tensor_parallel.checkpoint(
                functools.partial(self.mlp, padding_mask=padding_mask),
                False,
                pre_mlp_layernorm_output,
            )
    elif should_chunk_mlp_for_prefill:
        # Chunk input along sequence dimension
        num_chunks = min(self.config.mlp_chunks_for_prefill, pre_mlp_layernorm_output.shape[0])
        chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0)

        # Compute outputs for each chunk.
        # NOTE(review): padding_mask is not forwarded on this path (it would also have to be
        # chunked along its sequence dimension); confirm it is always None/unused in prefill.
        outputs = [self.mlp(chunk) for chunk in chunks]

        # Aggregate chunk outputs: activations are re-joined along the sequence dimension.
        mlp_output = torch.cat([out for out, _ in outputs], dim=0)
        chunk_biases = [bias for _, bias in outputs if bias is not None]
        if not chunk_biases:
            bias_output = None
        elif all(
            bias.dim() == out.dim() and bias.shape[0] == out.shape[0]
            for out, bias in outputs
            if bias is not None
        ):
            # Per-token biases belong to their chunk, so stitch them back together exactly
            # like the outputs (stacking would also fail on an uneven last chunk).
            bias_output = torch.cat(chunk_biases, dim=0)
        else:
            # An input-independent bias (e.g. a bias parameter returned unapplied by the
            # output projection) is the same tensor for every chunk. Summing it across chunks
            # would scale it by num_chunks, so a single copy is used.
            bias_output = chunk_biases[0]
        mlp_output_with_bias = (mlp_output, bias_output)
    else:
        if using_fused_tp_inference_kernel:
            # Set the residual for fused reduce-scatter + add + layer-norm + all-gather
            # operation in MLP's fc2.
            self._set_fc2_residual(residual)
        mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask)
    nvtx_range_pop(suffix="mlp")

    if (
        self.is_moe_layer
        and self.config.cuda_graph_impl == "transformer_engine"
        and self.training
        and is_graph_capturing()
        and CudaGraphScope.moe_router in self.config.cuda_graph_scope
    ):
        if self.recompute_pre_mlp_layernorm:
            # Register the recompute hooks to all the cudagraph output tensors, because some
            # tensors are in parallel execution paths and they all need pre_mlp_layernorm to be
            # recomputed in backward pass. For example, the router path and the shared expert
            # path. So only register in one path is risky.
            for tensor in mlp_output_with_bias:
                self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(tensor)
        return list(mlp_output_with_bias) + [residual]
    else:
        return self._forward_post_mlp(mlp_output_with_bias, residual)
def _forward_post_mlp(
    self, mlp_output_with_bias: tuple[Tensor, Tensor | None], residual: Tensor
) -> Tensor:
    """
    Finish the MLP sub-block after the MLP itself has run.

    Steps, in order:
    1. Optionally register recompute of the pre-MLP layernorm.
    2. Apply bias-dropout-add with the residual, which is skipped on the fused
       TP-inference path.
    3. Optionally commit the MLP-norm offload group.
    4. Return a viewless copy of the result.

    Args:
        mlp_output_with_bias (tuple[Tensor, Tensor | None]): ``(output, bias)`` pair
            produced by the MLP; ``bias`` may be None.
        residual (Tensor): Residual tensor combined with the MLP output.

    Returns:
        Tensor: Transformed hidden states of shape [s, b, h].
    """
    from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
        FineGrainedActivationOffloadingInterface as off_interface,
    )

    fused_tp_inference = not self.training and self.config.inference_fuse_tp_communication

    if self.recompute_pre_mlp_layernorm:
        # Throw away the saved pre-MLP layernorm activation and have it recomputed from a
        # gradient hook attached to the MLP output tensor.
        self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
            mlp_output_with_bias[0]
        )

    # TODO: could we move `bias_dropout_add_exec_handler` itself
    # inside the module provided in the `bias_dropout_add_spec` module?
    nvtx_range_push(suffix="mlp_bda")
    if not fused_tp_inference:
        with self.bias_dropout_add_exec_handler():
            bda = self.mlp_bda(self.training, self.config.bias_dropout_fusion)
            hidden_states = bda(mlp_output_with_bias, residual, self.hidden_dropout)
    else:
        # Inference-optimized path: no bias and no dropout here, and the residual add was
        # already performed inside the MLP module.
        hidden_states = mlp_output_with_bias[0]
    nvtx_range_pop(suffix="mlp_bda")

    # The residual is an input to mlp_bda, so the MLP-norm offload commit has to wait
    # until bias-dropout-add has produced hidden_states.
    if self.offload_mlp_norm:
        hidden_states = off_interface.group_commit(
            hidden_states, name="mlp_norm", forced_released_tensors=[residual]
        )

    # The JIT-compiled bda can yield a 'view' tensor, which the MPU checkpoint function
    # rejects if it gets saved in its context. Making it viewless saves no memory; it
    # documents where the view came from.
    return make_viewless_tensor(
        inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
    )
def sharded_state_dict(
    self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
    """
    Build the sharded state dict for this transformer layer.

    The parent class produces the base sharded state dict. Any key renames declared in
    ``self.submodules_config.sharded_state_dict_keys_map`` are then applied, with both the
    source and target keys prefixed by ``prefix``.

    Args:
        prefix (str, optional): Prefix prepended to every key in the state dict.
        sharded_offsets (tuple, optional): Sharding offsets forwarded to the parent.
        metadata (Optional[dict], optional): Extra sharding metadata forwarded to the parent.

    Returns:
        ShardedStateDict: The (possibly key-remapped) sharded state of the layer.
    """
    state = super().sharded_state_dict(prefix, sharded_offsets, metadata)
    keys_map = self.submodules_config.sharded_state_dict_keys_map

    remap = {}
    for old_key, new_key in keys_map.items():
        remap[f'{prefix}{old_key}'] = f'{prefix}{new_key}'

    # Only touch the state dict when there is something to rename.
    if remap:
        apply_prefix_mapping(state, remap)
    return state
def configure_fused_tp_inference(
    self,
    skip_qkv_norm_and_all_gather: bool = False,
    fc2_next_layer_norm_weights: Optional[Tensor] = None,
):
    """
    Configure this layer for fused tensor-parallel communication during inference.

    Args:
        skip_qkv_norm_and_all_gather (bool): Value written to
            ``self_attention.linear_qkv.skip_norm_and_all_gather``.
        fc2_next_layer_norm_weights (Optional[Tensor]): The next layer's QKV norm
            weights, handed to this layer's MLP FC2 via
            ``_set_fc2_next_layer_norm_weights`` (None is passed through unchanged).
    """
    qkv = self.self_attention.linear_qkv
    qkv.skip_norm_and_all_gather = skip_qkv_norm_and_all_gather

    # The attention/mixer out_proj is given this layer's own MLP FC1 norm weights,
    # and FC1 is then told to skip its own norm + all-gather.
    own_fc1_norm = self.get_mlp_layer_norm_weights()
    self._set_proj_next_layer_norm_weights(own_fc1_norm)
    self.mlp.linear_fc1.skip_norm_and_all_gather = True

    # MLP FC2 is given the following layer's attention norm weights.
    self._set_fc2_next_layer_norm_weights(fc2_next_layer_norm_weights)
def _set_proj_next_layer_norm_weights(self, weights: Tensor):
    """Hand ``weights`` to the attention/mixer output projection (``linear_proj``) as its next-layer norm weights."""
    out_proj = self.self_attention.linear_proj
    out_proj._set_next_layer_norm_weights(weights)
def _set_fc2_next_layer_norm_weights(self, weights: Optional[Tensor]):
    """
    Hand next-layer norm weights to the MLP's FC2.

    When ``weights`` is None (the last layer has no successor), an uninitialized
    placeholder created with ``torch.empty_like`` on this layer's MLP FC1 norm weights
    is passed instead.
    """
    next_norm = (
        weights
        if weights is not None
        else torch.empty_like(self.get_mlp_layer_norm_weights())
    )
    self.mlp.linear_fc2._set_next_layer_norm_weights(next_norm)
def _set_proj_residual(self, residual: Tensor):
    """Hand ``residual`` to the attention/mixer output projection (``linear_proj``)."""
    out_proj = self.self_attention.linear_proj
    out_proj._set_residual(residual)
def _set_fc2_residual(self, residual: Tensor):
    """Hand ``residual`` to the MLP's second linear layer (``linear_fc2``)."""
    fc2 = self.mlp.linear_fc2
    fc2._set_residual(residual)
def get_mlp_layer_norm_weights(self) -> Tensor:
    """
    Fetch the layer-norm weights fused into the MLP's first linear layer (FC1).

    Returns:
        Tensor: ``self.mlp.linear_fc1.layer_norm_weight.data``, i.e. the underlying
        tensor without autograd tracking.
    """
    fc1_norm_weight = self.mlp.linear_fc1.layer_norm_weight
    return fc1_norm_weight.data
def get_qkv_layer_norm_weights(self) -> Tensor:
    """
    Fetch the layer-norm weights fused into the attention QKV projection.

    Returns:
        Tensor: ``self.self_attention.linear_qkv.layer_norm_weight.data``, i.e. the
        underlying tensor without autograd tracking.
    """
    qkv_norm_weight = self.self_attention.linear_qkv.layer_norm_weight
    return qkv_norm_weight.data
def get_layer_static_inputs(self, seq_length, micro_batch_size):
"""
Get the static inputs for the transformer layer. Besides the hidden_states that is
generated in GraphableMegatronModule, we also add the attention_mask.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the static inputs for the layer.
"""
static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size)
if not isinstance(self.self_attention, IdentityOp) and (
not self.config.cuda_graph_scope or CudaGraphScope.attn in self.config.cuda_graph_scope
):
slen_per_cp = seq_length // self.config.context_parallel_size
static_inputs["attention_mask"] = (
~(torch.tril(torch.ones((slen_per_cp, seq_length))).bool())
.to(torch.cuda.current_device())