forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformer_engine.py
More file actions
2761 lines (2415 loc) · 112 KB
/
transformer_engine.py
File metadata and controls
2761 lines (2415 loc) · 112 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import dataclasses
import enum
import inspect
import io
import os
import pickle
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, cast
import torch
import torch.nn.functional as F
from packaging.version import Version as PkgVersion
from torch import Tensor
from torch.nn.parameter import Parameter
from typing_extensions import override
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.enums import Fp4Recipe, Fp8Recipe
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_amax_reduction_group,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
model_parallel_is_initialized,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.quantization.quant_config import QuantizationConfig
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
get_expert_parallel_rng_tracker_name,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.torch_norm import LayerNormInterface
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
ensure_metadata_has_dp_cp_group,
is_layer_window_attention,
make_sharded_tensors_for_checkpoint,
)
from megatron.core.typed_torch import copy_signature
from megatron.core.utils import (
get_pg_rank,
get_pg_size,
get_te_version,
get_tensor_model_parallel_group_if_none,
is_te_min_version,
is_torch_min_version,
)
try:
import transformer_engine as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
HAVE_TE = True
except ImportError:
if TYPE_CHECKING:
# For type checking, treat transformer_engine as always available.
import transformer_engine as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
HAVE_TE = True
else:
from unittest.mock import MagicMock
te = MagicMock()
HAVE_TE = False
# Dictionary key under which a TE quantization config declares its schema type.
_TE_CONFIG_TYPE_KEY: str = "transformer_engine_config_type"


class TransformerEngineConfigType(enum.Enum):
    """Recognized values for the ``_TE_CONFIG_TYPE_KEY`` entry of a config dictionary."""

    TEQuantizationParams = "TEQuantizationParams"
@dataclasses.dataclass
class TEQuantizationRecipe:
    """Options for opening a quantization autocast context around a module's forward.

    Instances are normally built from a per-module quantization dictionary with
    :meth:`parse_from_config`, which validates the field names and the mutual
    consistency of the FP8/FP4 settings.
    """

    fp8_quantization_recipe: Optional[Fp8Recipe] = None
    """
    An FP8 quantization override if the module should use FP8.
    If neither FP8 nor FP4 quantization is configured, the module executes
    in high precision (BF16).
    """

    fp4_quantization_recipe: Optional[Fp4Recipe] = None
    """
    An FP4 quantization override if the module should use FP4.
    If neither FP8 nor FP4 quantization is configured, the module executes
    in high precision (BF16).
    """

    custom_recipe_factory: Optional[str] = None
    """The path to a custom recipe factory if a custom FP4 or FP8 recipe is configured"""

    fp8_format: str = "e4m3"
    """The FP8 format to use together with an FP8 recipe ("e4m3" or "hybrid")"""

    override_quantized_autocast: bool = True
    """
    If the quantization autocast context for a targeted module is enabled,
    whether to override it and change (or disable) the quantization recipe.
    """

    override_nonquantized_autocast: bool = False
    """
    If the quantization autocast context for a targeted module is not enabled,
    whether to override it and enable a quantization recipe.
    """

    tp_only_amax_red: bool = False
    """
    If an amax reduction is applicable, such as in per-tensor quantization recipe,
    whether to reduce only along TP groups.
    """

    @classmethod
    def parse_from_config(cls, quant_config: Dict[Any, Any]) -> "TEQuantizationRecipe":
        """
        Parse config from quantization dictionary.

        Args:
            quant_config: Mapping whose keys must all be field names of this dataclass;
                missing fields keep their defaults.

        Returns:
            A validated instance of ``cls``.

        Raises:
            ValueError: If ``quant_config`` contains a key that is not a field, requests
                delayed FP8 scaling, sets both an FP8 and an FP4 recipe, or selects a
                custom recipe without ``custom_recipe_factory``.
        """
        class_keys = cls.get_config_keys()
        for key in quant_config:
            if key not in class_keys:
                raise ValueError(f"Field '{key}' not valid for this configuration.")

        # Construct through `cls` (not a hard-coded class name) so that subclasses
        # parse into their own type, consistent with `cls.get_config_keys()` above.
        instance = cls(**{key: quant_config[key] for key in class_keys if key in quant_config})

        if instance.fp8_quantization_recipe == Fp8Recipe.delayed:
            raise ValueError("Delayed scaling not in scope of te per-module quantization config.")
        if (
            instance.fp8_quantization_recipe is not None
            and instance.fp4_quantization_recipe is not None
        ):
            raise ValueError("fp8 and fp4 quantization settings are mutually exclusive.")
        if (
            instance.fp8_quantization_recipe == Fp8Recipe.custom
            or instance.fp4_quantization_recipe == Fp4Recipe.custom
        ):
            if instance.custom_recipe_factory is None:
                raise ValueError("custom fp8 or fp4 recipe requires custom_recipe_factory")
        return instance

    @classmethod
    def get_config_keys(cls) -> Set[str]:
        """Get expected keys from the dataclass fields."""
        return {field.name for field in dataclasses.fields(cls)}
@dataclasses.dataclass
class TEQuantizationParams:
    """Class to capture precision options for training and evaluation."""

    training_recipe: TEQuantizationRecipe
    """Precision override for when self.training is True"""

    evaluation_recipe: Optional[TEQuantizationRecipe]
    """
    Precision override for when self.training is False.
    If None, training_recipe is used.
    """

    @staticmethod
    def parse_from_config(quant_config: QuantizationConfig) -> "TEQuantizationParams":
        """Parse the quantization config for a layer, raising on invalid input.

        ``quant_config.config`` must be a dictionary containing the
        ``_TE_CONFIG_TYPE_KEY`` key and a ``training_recipe`` entry. It may also
        contain an ``evaluation_recipe`` entry. No other keys are accepted.

        Raises:
            ValueError: If the type key is missing or has an unsupported value,
                ``training_recipe`` is missing, an unexpected key is present, or a
                recipe fails to parse.
            NotImplementedError: For a recognized but unhandled config type.
        """
        config = quant_config.config
        try:
            config_type = TransformerEngineConfigType(config[_TE_CONFIG_TYPE_KEY])
        except KeyError:
            raise ValueError(
                f"TransformerEngine config dictionary must have '{_TE_CONFIG_TYPE_KEY}' key."
            ) from None
        except ValueError as err:
            raise ValueError(f"Unsupported config type '{config[_TE_CONFIG_TYPE_KEY]}'.") from err

        if config_type != TransformerEngineConfigType.TEQuantizationParams:
            raise NotImplementedError(f"Unhandled configuration type {config_type}")

        if 'training_recipe' not in config:
            raise ValueError(
                "TransformerEngine config dictionary must have 'training_recipe' key"
            )
        # Validate explicitly instead of with `assert`, which is stripped under `python -O`
        # and previously produced a message-less AssertionError.
        allowed_keys = {_TE_CONFIG_TYPE_KEY, 'training_recipe', 'evaluation_recipe'}
        unexpected_keys = [key for key in config if key not in allowed_keys]
        if unexpected_keys:
            raise ValueError(
                f"Unexpected keys in TransformerEngine config dictionary: {unexpected_keys}"
            )

        training_recipe = TEQuantizationRecipe.parse_from_config(config['training_recipe'])
        evaluation_recipe = None
        if 'evaluation_recipe' in config:
            evaluation_recipe = TEQuantizationRecipe.parse_from_config(
                config['evaluation_recipe']
            )
        return TEQuantizationParams(
            training_recipe=training_recipe, evaluation_recipe=evaluation_recipe
        )
def _get_fp8_autocast_for_quant_recipe(qrecipe: TEQuantizationRecipe):
    """Return an autocast context manager that applies ``qrecipe``.

    The function returns one of three things:

    * ``nullcontext()`` when the recipe does not override the current autocast
      state. ``override_quantized_autocast`` applies if FP8 autocast is currently
      enabled; otherwise ``override_nonquantized_autocast`` applies.
    * ``fp8_autocast(enabled=False)`` when neither an FP8 nor an FP4 recipe is set.
    * Otherwise, an enabled ``fp8_autocast`` whose recipe is built from ``qrecipe``.
      When model parallelism is initialized, it also carries the amax reduction
      group.

    Raises:
        ValueError: For an unhandled ``fp8_format``, FP8 recipe, or FP4 recipe.
    """
    # Leave the ambient autocast state alone unless the recipe asks to override it.
    if FP8GlobalStateManager.is_fp8_enabled():
        if not qrecipe.override_quantized_autocast:
            return nullcontext()
    else:
        if not qrecipe.override_nonquantized_autocast:
            return nullcontext()

    if qrecipe.fp8_quantization_recipe is None and qrecipe.fp4_quantization_recipe is None:
        # Force BF16 for this layer and override autocast
        return fp8_autocast(enabled=False)

    amax_group = None
    if model_parallel_is_initialized():
        amax_group = get_amax_reduction_group(
            with_context_parallel=True, tp_only_amax_red=qrecipe.tp_only_amax_red
        )

    if (
        qrecipe.fp8_quantization_recipe == Fp8Recipe.custom
        or qrecipe.fp4_quantization_recipe == Fp4Recipe.custom
    ):
        from megatron.core.fp8_utils import _get_custom_recipe

        assert qrecipe.custom_recipe_factory is not None
        quant_recipe = _get_custom_recipe(qrecipe.custom_recipe_factory)
    elif qrecipe.fp8_quantization_recipe is not None:
        if qrecipe.fp8_format == "e4m3":
            fp8_format = te.common.recipe.Format.E4M3
        elif qrecipe.fp8_format == "hybrid":
            fp8_format = te.common.recipe.Format.HYBRID
        else:
            raise ValueError(f"Unhandled fp8_format {qrecipe.fp8_format}")

        if qrecipe.fp8_quantization_recipe == Fp8Recipe.tensorwise:
            quant_recipe = te.common.recipe.Float8CurrentScaling(fp8_format=fp8_format)
        elif qrecipe.fp8_quantization_recipe == Fp8Recipe.blockwise:
            quant_recipe = te.common.recipe.Float8BlockScaling(fp8_format=fp8_format)
        elif qrecipe.fp8_quantization_recipe == Fp8Recipe.mxfp8:
            quant_recipe = te.common.recipe.MXFP8BlockScaling(fp8_format=fp8_format)
        else:
            raise ValueError(f"Unhandled fp8 recipe: {qrecipe.fp8_quantization_recipe}")
    else:
        # Fp4 configured.
        if qrecipe.fp4_quantization_recipe == Fp4Recipe.nvfp4:
            quant_recipe = te.common.recipe.NVFP4BlockScaling()
        else:
            # Report the FP4 setting; the FP8 one is always None on this branch.
            raise ValueError(f"Unhandled fp4 recipe: {qrecipe.fp4_quantization_recipe}")
    return fp8_autocast(enabled=True, fp8_recipe=quant_recipe, fp8_group=amax_group)
def _get_fp8_autocast_for_quant_params(qparams: TEQuantizationParams | None, training: bool):
    """Select the recipe for the current mode and return its autocast context.

    Without params, a ``nullcontext()`` is returned. In eval mode, the evaluation
    recipe is used when one is configured. In every other case, the training recipe
    is used.
    """
    if qparams is None:
        return nullcontext()
    use_eval_recipe = (not training) and qparams.evaluation_recipe is not None
    chosen = qparams.evaluation_recipe if use_eval_recipe else qparams.training_recipe
    return _get_fp8_autocast_for_quant_recipe(chosen)
def _get_should_context_be_quantized_recipe(
    qrecipe: TEQuantizationRecipe, is_original_context_quantized: bool
):
    """Report whether a forward pass under ``qrecipe`` would run quantized.

    The relevant override flag depends on whether the original context is quantized.
    If that flag is off, the original state is returned unchanged. Otherwise the
    result is True exactly when an FP8 or an FP4 recipe is configured.
    """
    if is_original_context_quantized:
        overrides_context = qrecipe.override_quantized_autocast
    else:
        overrides_context = qrecipe.override_nonquantized_autocast

    if not overrides_context:
        return is_original_context_quantized

    # No FP8/FP4 recipe means the override forces high precision.
    has_quant_recipe = (
        qrecipe.fp8_quantization_recipe is not None
        or qrecipe.fp4_quantization_recipe is not None
    )
    return has_quant_recipe
def _get_should_context_be_quantized_params(
    qparams: TEQuantizationParams | None, training: bool, is_context_quantized: bool
):
    """Mode-aware wrapper around ``_get_should_context_be_quantized_recipe``.

    Without params, the incoming ``is_context_quantized`` is returned unchanged.
    Otherwise the evaluation recipe is consulted in eval mode when it exists, and
    the training recipe in every other case.
    """
    if qparams is None:
        return is_context_quantized
    if training or qparams.evaluation_recipe is None:
        active_recipe = qparams.training_recipe
    else:
        active_recipe = qparams.evaluation_recipe
    return _get_should_context_be_quantized_recipe(active_recipe, is_context_quantized)
def _get_extra_te_kwargs(config: TransformerConfig):
    """Build the keyword arguments shared by the TE module constructors in this file.

    ``params_dtype`` is always included. On TE >= 0.12.0 a ``device`` is also chosen,
    in this order of precedence:

    * ``"cpu"`` when ``use_cpu_initialization`` is set;
    * ``"meta"`` when ``init_model_with_meta_device`` is set;
    * otherwise, the current CUDA device.
    """
    te_kwargs = {"params_dtype": config.params_dtype}
    if not is_te_min_version("0.12.0"):
        return te_kwargs

    if config.use_cpu_initialization:
        target_device = "cpu"
    elif config.init_model_with_meta_device:
        target_device = "meta"
    else:
        target_device = torch.cuda.current_device()
    te_kwargs["device"] = target_device
    return te_kwargs
def condition_init_method(config, init_method):
    """Condition TE init_method on config.perform_initialization.

    Returns ``init_method`` itself when ``config.perform_initialization`` is truthy.
    Otherwise returns a fresh no-op callable that accepts one argument and returns
    None.
    """
    if config.perform_initialization:
        return init_method

    def _skip_init(w):
        return None

    return _skip_init
def split_te_layernorm_column_parallel_linear(
    fused_layer,
    config,
    init_method: Optional[Callable] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
    """
    Split a TELayerNormColumnParallelLinear into separate TENorm and TEColumnParallelLinear layers.

    Args:
        fused_layer: The fused TELayerNormColumnParallelLinear layer to split
        config: TransformerConfig to use for creating the new layers
        init_method: Initialization method for the linear layer (optional). When None,
            a no-op initializer is used because the weights are copied afterwards.
        tp_group: Tensor parallel group (optional). When None, ``fused_layer.tp_group``
            is used.

    Returns:
        A tuple of (TENorm, TEColumnParallelLinear) with weights copied from the fused layer
    """
    # The fused layer's out_features is multiplied by its tp_size to obtain the
    # output size handed to TEColumnParallelLinear.
    in_features = fused_layer.in_features
    out_features = fused_layer.out_features * fused_layer.tp_size

    # Create the norm layer and copy its parameters.
    norm_layer = TENorm(config=config, hidden_size=in_features, eps=fused_layer.eps)
    with torch.no_grad():
        norm_layer.weight.copy_(fused_layer.layer_norm_weight)
        # Copy the norm bias only when both layers expose one.
        if hasattr(norm_layer, 'bias') and hasattr(fused_layer, 'layer_norm_bias'):
            if fused_layer.layer_norm_bias is not None:
                norm_layer.bias.copy_(fused_layer.layer_norm_bias)

    # Create the column parallel linear layer
    linear_layer = TEColumnParallelLinear(
        input_size=in_features,
        output_size=out_features,
        config=config,
        # Dummy init by default since the weights are copied below.
        init_method=init_method if init_method is not None else (lambda x: None),
        gather_output=False,
        bias=fused_layer.use_bias,
        skip_bias_add=fused_layer.te_return_bias,
        is_expert=False,
        tp_comm_buffer_name=fused_layer.ub_name,
        tp_group=tp_group if tp_group is not None else fused_layer.tp_group,
    )

    with torch.no_grad():
        linear_layer.weight.copy_(fused_layer.weight)
        if fused_layer.use_bias and hasattr(fused_layer, 'bias'):
            linear_layer.bias.copy_(fused_layer.bias)

    # TODO(Peter): Do we need this
    # Copy FP8 metadata if applicable. Only keys already present on the new layer are
    # written. Nested dicts are merged one level deep; other values are assigned by
    # reference, so the two layers share those objects.
    fused_meta = getattr(fused_layer, 'fp8_meta', None)
    if fused_meta is not None and hasattr(linear_layer, 'fp8_meta'):
        target_meta = linear_layer.fp8_meta
        for key in fused_meta:
            if key not in target_meta:
                continue
            value = fused_meta[key]
            if isinstance(value, dict):
                target_sub = target_meta[key]
                for subkey in value:
                    if subkey in target_sub:
                        target_sub[subkey] = value[subkey]
            else:
                target_meta[key] = value

    # Set the same configuration flags
    linear_layer.sequence_parallel = fused_layer.sequence_parallel
    linear_layer.is_first_microbatch = fused_layer.is_first_microbatch
    linear_layer.disable_parameter_transpose_cache = fused_layer.disable_parameter_transpose_cache

    return norm_layer, linear_layer
if HAVE_TE and is_te_min_version("1.13.0"):

    class TEActivationOp:
        """
        A conditional wrapper to initialize an instance of Transformer-Engine's activation
        function operators (e.g. Silu, SwiGLU, etc)

        Supported mappings from ``config.activation_func``:

        * gated (``config.gated_linear_unit``): ``F.silu`` -> SwiGLU,
          ``F.gelu`` -> GEGLU, ``F.relu`` -> ReGLU;
        * non-gated: ``F.gelu`` -> GELU, ``F.relu`` -> ReLU.

        Any other combination raises.
        """

        def __new__(cls, config: TransformerConfig):
            layer_type = None
            act_fn = config.activation_func
            if config.gated_linear_unit:
                if act_fn == F.silu:
                    layer_type = te.pytorch.ops.SwiGLU
                elif act_fn == F.gelu:
                    layer_type = te.pytorch.ops.GEGLU
                elif act_fn == F.relu:
                    # Previously compared against F.silu, which made ReGLU unreachable.
                    layer_type = te.pytorch.ops.ReGLU
            else:
                if act_fn == F.gelu:
                    layer_type = te.pytorch.ops.GELU
                elif act_fn == F.relu:
                    # Previously compared against F.silu, silently mapping SiLU to ReLU.
                    layer_type = te.pytorch.ops.ReLU

            if layer_type is None:
                raise Exception(
                    'Only SwiGLU, GEGLU, ReGLU, GELU, ReLU are supported by '
                    'transformer engine. Please set use_te_activation_func=False'
                )

            activation_func_kwargs = {}
            if config.activation_func_fp8_input_store:
                activation_func_kwargs["cache_quantized_input"] = True
            layer = layer_type(**activation_func_kwargs)
            return layer

else:
    TEActivationOp = None
if HAVE_TE and is_te_min_version("1.13.0"):

    class TEFusedResidualRMSNorm(te.pytorch.RMSNorm):
        """
        RMSNorm with fused residual output for Megatron Core.

        Inherits from te.pytorch.RMSNorm to maintain all parameter management,
        checkpoint compatibility, and Megatron-specific features. Creates a fused
        implementation using TE's ops API that shares the base class parameters.

        The fused implementation uses:
        - MakeExtraOutput: Forks the residual connection
        - RMSNorm: Normalizes the main path

        Forward pass returns: (normalized_output, residual)

        The fused pipeline is built lazily on the first ``forward`` call and cached.
        Module hooks present at that moment are replayed on the pipeline by
        ``_register_hooks_on_fused_impl``.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Fused implementation (stored in tuple to avoid submodule registration)
            self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None

        def _make_fused_impl(self) -> te.pytorch.ops.Sequential:
            """
            Construct fused ops pipeline that shares parameters with base RMSNorm.

            Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares
            the weight parameter with self.weight from the base class.
            Hooks currently registered on this module are attached to the pipeline
            before it is returned.
            """
            fused_impl = te.pytorch.ops.Sequential()

            # Op 1: MakeExtraOutput - forks the residual
            fused_impl.append(te.pytorch.ops.MakeExtraOutput())

            # Op 2: RMSNorm - shares weight parameter with self
            kwargs = {
                "eps": self.eps,
                # Built on the meta device: its own weight is replaced by self.weight below.
                "device": "meta",  # Already initialized
                "dtype": self.weight.dtype,
                "zero_centered_gamma": self.zero_centered_gamma,
            }
            # Add sm_margin if available (TE 2.5+)
            # NOTE(review): forwards the base module's `_sm_margins` as-is; confirm that
            # the ops.RMSNorm `sm_margin` argument accepts the same structure.
            if hasattr(self, '_sm_margins'):
                kwargs["sm_margin"] = self._sm_margins
            rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs)
            # Share the parameter object so checkpoints/optimizers see a single weight.
            rmsnorm_op.weight = self.weight
            fused_impl.append(rmsnorm_op)

            self._register_hooks_on_fused_impl(fused_impl)

            return fused_impl

        def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:
            """Replay hooks registered on this module tree onto ``fused_impl``.

            Forward pre- and post-hooks are wrapped and invoked with ``None`` in place of
            the input/output tensors, because the fused pipeline does not expose them.
            A wrapped hook that returns a non-None value raises RuntimeError.

            A warning is issued when any pre-forward hook does not come from the
            ``distributed_data_parallel`` module, and whenever a post-forward hook is
            present.

            Backward pre- and post-hooks are unsupported and raise RuntimeError
            immediately.

            NOTE(review): ``self.modules()`` yields ``self`` too. Hooks registered
            directly on this module are therefore collected and replayed here, in
            addition to running through the normal call on this module. Confirm that
            this double invocation is intended.

            Hooks registered after the fused pipeline is built are not replayed,
            because the hook lists are captured once, here.
            """
            forward_pre_hooks = []
            forward_post_hooks = []
            backward_pre_hooks = []
            backward_post_hooks = []
            for submodule in self.modules():
                for hook in submodule._forward_pre_hooks.values():
                    forward_pre_hooks.append((submodule, hook))
                for hook in submodule._forward_hooks.values():
                    forward_post_hooks.append((submodule, hook))
                for hook in submodule._backward_pre_hooks.values():
                    backward_pre_hooks.append((submodule, hook))
                for hook in submodule._backward_hooks.values():
                    backward_post_hooks.append((submodule, hook))

            # Pre-forward hooks
            # Note: DDP pre-forward hooks are safe since they do not
            # interact with input tensor.
            if forward_pre_hooks:
                from megatron.core.distributed import distributed_data_parallel

                if any(
                    inspect.getmodule(hook) != distributed_data_parallel
                    for _, hook in forward_pre_hooks
                ):
                    warnings.warn(
                        "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. "
                        "TEFusedResidualRMSNorm module does not expose intermediate tensors, "
                        "so the hook may have incorrect behavior if it attempts to "
                        "access the input tensor."
                    )

                def forward_pre_hook(module, *_) -> None:
                    for submodule, hook in forward_pre_hooks:
                        # Assume that hook does not interact with input
                        ret = hook(submodule, None)
                        if ret is not None:
                            raise RuntimeError(
                                "TEFusedResidualRMSNorm module does not expose "
                                "intermediate tensors, but submodule has "
                                "pre-forward hook that modifies input tensor."
                            )

                fused_impl.register_forward_pre_hook(forward_pre_hook)

            # Post-forward hooks
            if forward_post_hooks:
                warnings.warn(
                    "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. "
                    "TEFusedResidualRMSNorm module does not expose intermediate tensors, "
                    "so the hook may have incorrect behavior if it attempts to "
                    "access the input or output tensors."
                )

                def forward_post_hook(module, *_) -> None:
                    for submodule, hook in forward_post_hooks:
                        # Assume that hook does not interact with input or output
                        ret = hook(submodule, None, None)
                        if ret is not None:
                            raise RuntimeError(
                                "TEFusedResidualRMSNorm module does not expose "
                                "intermediate tensors, but submodule has "
                                "post-forward hook that modifies output tensor."
                            )

                fused_impl.register_forward_hook(forward_post_hook)

            # Backward hooks
            if backward_pre_hooks:
                raise RuntimeError(
                    "TEFusedResidualRMSNorm module does not support "
                    "submodules with pre-backward hooks"
                )
            if backward_post_hooks:
                raise RuntimeError(
                    "TEFusedResidualRMSNorm module does not support "
                    "submodules with post-backward hooks"
                )

        def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Forward pass with fused residual output.

            Args:
                hidden_states: Input tensor [s, b, h]

            Returns:
                Tuple of (normalized_output, residual), both [s, b, h]

            Note:
                Sequential.forward() automatically returns (output, extra_outputs...)
                when MakeExtraOutput is present, so we don't need manual unpacking.
            """
            # Construct fused impl lazily on first forward
            # (in case parameters are modified after __init__)
            if self._fused_impl is None:
                self._fused_impl = (self._make_fused_impl(),)

            # Apply fused implementation
            # Sequential returns (normalized_output, residual) automatically
            return self._fused_impl[0](hidden_states)

else:
    TEFusedResidualRMSNorm = None  # type: ignore[assignment, misc]
class TENorm:
    """Factory that builds Transformer-Engine's `LayerNorm` or `RMSNorm` for a config.

    The choice follows ``config.normalization``. Residual fusion is opt-in at two
    levels, and both must hold:

    1. Global capability: ``config.fused_residual_rmsnorm`` is True.
    2. Local intent: the build site passes ``has_residual=True``, declaring that this
       particular norm is followed by a residual connection.

    When both hold, the result is a ``TEFusedResidualRMSNorm``. That combination
    requires RMSNorm normalization.
    """

    # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
    def __new__(
        cls,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
        has_residual: bool = False,
    ):
        if not HAVE_TE:
            raise ImportError(
                "Transformer Engine is not installed. "
                "Please install it with `pip install transformer-engine`."
            )

        normalization = config.normalization
        fuse_residual = config.fused_residual_rmsnorm and has_residual

        # Residual fusion is only implemented on top of RMSNorm.
        if fuse_residual and normalization != "RMSNorm":
            raise ValueError("Fused residual is only supported for RMSNorm normalization")
        if normalization not in ("LayerNorm", "RMSNorm"):
            raise Exception("Only LayerNorm and RMSNorm are currently supported")

        if normalization == "LayerNorm":
            norm_cls = te.pytorch.LayerNorm
        else:
            assert hasattr(
                te.pytorch, "RMSNorm"
            ), "Transformer-Engine >= v0.11 required to use this feature"
            if fuse_residual:
                assert (
                    TEFusedResidualRMSNorm is not None
                ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0"
                norm_cls = TEFusedResidualRMSNorm
            else:
                norm_cls = te.pytorch.RMSNorm

        te_kwargs = _get_extra_te_kwargs(config)
        norm = norm_cls(
            normalized_shape=hidden_size,
            eps=eps,
            sequence_parallel=config.sequence_parallel,
            zero_centered_gamma=config.layernorm_zero_centered_gamma,
            **te_kwargs,
        )
        return cast(LayerNormInterface, norm)
class TELinear(te.pytorch.Linear):
"""Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
parallel_mode currently supports 3 different values:
- "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
- "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
- "duplicated": No tensor parallelism and weight is duplicated across TP ranks
- Note: For expert linear layers, we will disable communication logic here
as TP communication is handled in token_dispatcher.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
symmetric_ar_type: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. "
"Please install it with `pip install transformer-engine`."
)
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
self.symmetric_ar_type = symmetric_ar_type
if skip_weight_param_allocation:
raise ValueError(
"Transformer Engine linear layers do not support skip_weight_param_allocation"
)
extra_kwargs = _get_extra_te_kwargs(config)
if self.config.delay_wgrad_compute:
if is_te_min_version("2.3.0"):
extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute
else:
raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.")
if (
self.config.tp_comm_overlap
and tp_comm_buffer_name
and tp_comm_buffer_name not in ["qkv", "proj", "fc1", "fc2"]
):
self.config.tp_comm_overlap = False
warnings.warn(
f"The user buffer name {tp_comm_buffer_name} is not supported in"
"Transformer Engine. Disabling TP communication overlap "
"for this layer."
)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap and parallel_mode != "duplicated":
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_split_ag"] = False
extra_kwargs["ub_atomic_gemm_ag"] = False
extra_kwargs["ub_split_rs"] = False
extra_kwargs["ub_atomic_gemm_rs"] = False
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
if symmetric_ar_type is not None:
assert is_torch_min_version("2.7.0a0"), "Must have at least torch version 2.7 or higher"
assert is_te_min_version("2.3.0") or get_te_version() == PkgVersion(
"2.3.0.dev0+39c0e70"
), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"
extra_kwargs["symmetric_ar_type"] = symmetric_ar_type
if parallel_mode == "duplicated":
assert tp_group is None, "duplicated linear should not have tp_group set"
tp_size = 1
else:
tp_size = get_pg_size(tp_group)
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
if parallel_mode == "duplicated":
rng_tracker_name = get_data_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name
te_parallel_mode = parallel_mode
tp_group_for_te = tp_group
if parallel_mode == "duplicated":
# Handle non-parallel case
tp_group_for_te = None
tp_size = 1
explicit_expert_comm = False
te_parallel_mode = None
else:
# Disable communications in TE when using TP or EP by
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
te_parallel_mode = None
tp_size = 1
tp_group_for_te = None
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
# Pass None if not initialized for backward compatibility with the ckpt converter.
tp_group=tp_group_for_te if torch.distributed.is_initialized() else None,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=te_parallel_mode,
**extra_kwargs,
)
self.te_quant_params: Optional[TEQuantizationParams] = None
for param in self.parameters():
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, "allreduce", not self.expert_parallel)
else:
# Reduce the gradient on DP group
setattr(param, "allreduce", True)
if parallel_mode == "duplicated":
# Reduce the gradient further on the TP group since the weight is
# duplicated across TP ranks
setattr(param, "sequence_parallel", self.config.sequence_parallel)
# Mark as NOT tensor parallel since weight is duplicated
setattr(param, "tensor_model_parallel", False)
tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
self._tp_group = tp_group
def finish_init(self, quantization_config: QuantizationConfig):
"""Post-init of quantization override"""
if quantization_config is None:
self.te_quant_params = None
else:
self.te_quant_params = TEQuantizationParams.parse_from_config(quantization_config)
def will_execute_quantized(self, is_context_quantized: bool) -> bool:
    """Tell whether this module is configured to execute quantized.

    Args:
        is_context_quantized: Whether the surrounding context is quantized.

    Returns:
        The decision of ``_get_should_context_be_quantized_params`` for this
        module's ``te_quant_params`` and current ``training`` flag.
    """
    quant_params = self.te_quant_params
    in_training = self.training
    return _get_should_context_be_quantized_params(
        quant_params, in_training, is_context_quantized
    )
def forward(self, x):
    """Run the TE linear layer and always return an ``(output, bias)`` pair.

    ``is_first_microbatch`` is forwarded to TE as ``None`` when the parameter
    transpose cache is disabled. Otherwise the tracked flag is forwarded, and it
    is cleared after the call. The TE forward runs inside the autocast context
    derived from ``te_quant_params`` and the current ``training`` flag.
    """
    if self.disable_parameter_transpose_cache:
        first_microbatch_flag = None
    else:
        first_microbatch_flag = self.is_first_microbatch

    with _get_fp8_autocast_for_quant_params(self.te_quant_params, self.training):
        result = super().forward(x, is_first_microbatch=first_microbatch_flag)
    self.is_first_microbatch = False

    # TE hands back a tuple only when return_bias is True and a bare tensor
    # otherwise; normalise so callers always unpack two values.
    return result if self.te_return_bias else (result, None)
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    """Replicate cross TP/DP.

    Provides the dist-ckpt support when TELinear is directly used, which can only
    happen with duplicated parallel mode.

    Args:
        prefix: Key prefix for the produced sharded tensors.
        sharded_offsets: Offsets forwarded to ``make_sharded_tensors_for_checkpoint``.
        metadata: Checkpoint metadata; must be a mapping with a ``"dp_cp_group"``
            entry. ``None`` (the signature default) is rejected.

    Returns:
        The result of ``make_sharded_tensors_for_checkpoint`` over this module's
        state dict, using this layer's TP group and ``metadata["dp_cp_group"]``.

    Raises:
        AssertionError: if TE's ``parallel_mode`` is not ``None``.
        TypeError: if ``metadata`` is ``None``.
        KeyError: if ``metadata`` lacks ``"dp_cp_group"``.
    """
    assert (
        self.parallel_mode is None
    ), "TELinear sharded_state_dict can only be used with duplicated parallel mode"
    # The signature defaults ``metadata`` to None, but the DP/CP group is required
    # below. Fail fast with an actionable message instead of the opaque
    # "'NoneType' object is not subscriptable" after building the state dict.
    if metadata is None:
        raise TypeError(
            "TELinear.sharded_state_dict requires `metadata` with a 'dp_cp_group' entry"
        )
    state_dict = self.state_dict(prefix="", keep_vars=True)
    return make_sharded_tensors_for_checkpoint(
        state_dict,
        prefix,
        None,
        sharded_offsets,
        tp_group=self._tp_group,
        dp_cp_group=metadata["dp_cp_group"],
    )
def backward_dw(self):
    """Run the deferred weight-gradient computation, if enabled.

    Delegates to the parent's ``backward_dw`` only when
    ``config.delay_wgrad_compute`` is truthy; otherwise it does nothing.
    """
    if not self.config.delay_wgrad_compute:
        return
    super().backward_dw()
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""Wrapper for the Transformer-Engine's `LayerNormLinear` layer
that combines layernorm and linear layers."""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
stride: int = 1,
):
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. "
"Please install it with `pip install transformer-engine`."
)
self.config = config
if gather_output:
raise ValueError("Transformer Engine linear layers do not support gather_output = True")
if is_expert:
raise ValueError("Transformer Engine linear layers do not yet support MoE")
if skip_weight_param_allocation:
raise ValueError(
"Transformer Engine linear layers do not support skip_weight_param_allocation"
)
# TODO: For backward compatibility, remove in v0.15.
tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
self._tp_group = tp_group
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
self.tp_size = get_pg_size(tp_group)
self.tp_rank = get_pg_rank(tp_group)
if self.config.delay_wgrad_compute:
if is_te_min_version("2.3.0"):
extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute
else:
raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.")
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if is_te_min_version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
te_version = get_te_version()
raise ValueError(
f"Transformer Engine v{te_version} does not support {self.config.normalization}."
)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if is_te_min_version("1.5.0", check_equality=False):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if is_te_min_version("1.6.0.dev0", check_equality=False):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
if tp_comm_buffer_name == "qkv" and self.config.tp_comm_overlap_disable_qkv:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
if tp_comm_buffer_name == "fc1" and self.config.tp_comm_overlap_disable_fc1:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag