forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformer_config.py
More file actions
2287 lines (1937 loc) · 107 KB
/
transformer_config.py
File metadata and controls
2287 lines (1937 loc) · 107 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import logging
import math
import warnings
from dataclasses import dataclass, field
from typing import Callable, List, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from megatron.core.enums import Fp4Recipe, Fp8Recipe
from megatron.core.quantization.quant_config import RecipeConfig
from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout
from .._rank_utils import log_single_rank
from ..fusions.fused_bias_geglu import quick_gelu
from ..model_parallel_config import ModelParallelConfig
from ..utils import (
get_te_version,
init_method_normal,
is_te_min_version,
is_torch_min_version,
mup_scaled_init_method_normal,
scaled_init_method_normal,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# `packaging` is an optional dependency. HAVE_PACKAGING records whether it could be
# imported; PkgVersion is only bound when the import succeeds.
try:
    from packaging.version import Version as PkgVersion
except ImportError:
    HAVE_PACKAGING = False
else:
    HAVE_PACKAGING = True
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter,
including those in ModelParallelConfig.
"""
####################
# model architecture
####################
num_layers: int = field(default=0, metadata={"argparse_meta": {"default": None}})
"""Number of transformer layers in a transformer block."""
mtp_num_layers: Optional[int] = None
"""Number of Multi-Token Prediction (MTP) Layers.
MTP extends the prediction scope to multiple future tokens at each position.
This MTP implementation sequentially predicts additional tokens
by using D sequential modules to predict D additional tokens.
"""
mtp_loss_scaling_factor: Optional[float] = 0.1
"""Weighting factor of Multi-Token Prediction (MTP) loss.
We compute the average of the MTP losses across all depths,
and multiply it by the scaling factor to obtain the overall MTP loss,
which serves as an additional training objective.
"""
mtp_use_repeated_layer: bool = False
"""Use a single MTP layer repeatedly instead of multiple separate layers."""
mtp_hybrid_override_pattern: Optional[str] = None
"""DEPRECATED: Use unified hybrid_layer_pattern instead.
Legacy argument for loading old checkpoints.
Force a specific hybrid layer pattern for MTP layers.
"""
num_layers_in_first_pipeline_stage: Optional[int] = None
"""Number of transformer layers on first pipeline stage.
None implies equal layer division across PP ranks."""
num_layers_in_last_pipeline_stage: Optional[int] = None
"""Number of transformer layers on last pipeline stage.
None implies equal layer division across PP ranks."""
pipeline_model_parallel_layout: Optional[Union[str, list, PipelineParallelLayerLayout]] = None
"""Custom definition of the pipeline parallel partitioning.
Support type:
- str: e.g., 'Et*3|(tt|)*29,m|L'. Stages are split by '|', replicated stages or layers
can be described with multiplication. Commas can be used cosmetically.
- list: e.g., [['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']].
- PipelineParallelLayerLayout: a PipelineParallelLayerLayout object.
If given either a string or a list, it will be transferred into a PipelineParallelLayerLayout
in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers
in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ...,
vpp(i)pp(j), vpp(i)pp(j+1), ..., vpp(-1)pp(-2), vpp(-1)pp(-1).
In the inner lists of layers, 'embedding' or 'E' denotes the embedding layer, 'loss' or 'L'
denotes the loss function, and 'decoder' or 't' denotes the transformer decoder layer.
Examples:
[['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']]:
pp = 2, vpp = None
pp rank 0 holds: embedding, decoder
pp rank 1 holds: decoder*3, loss
'E|(tt|)*2,(t|)*4,mL':
pp = 4, vpp = 2
vpp rank 0 pp rank 0 holds: embedding
vpp rank 0 pp rank 1~2 holds: decoder*2
vpp rank 0 pp rank 3 holds: decoder
vpp rank 1 pp rank 0~2 holds: decoder
vpp rank 1 pp rank 3 holds: mtp, loss"""
account_for_embedding_in_pipeline_split: bool = False
"""If set, the embedding layer will be treated as a standard transformer
layer in the context of partition and placement for pipeline parallelism."""
account_for_loss_in_pipeline_split: bool = False
"""If set, the loss layer will be treated as a standard transformer
layer in the context of partition and placement for pipeline parallelism."""
hidden_size: int = field(default=0, metadata={"argparse_meta": {"default": None}})
"""Transformer hidden size."""
num_attention_heads: int = field(default=0, metadata={"argparse_meta": {"default": None}})
"""Number of transformer attention heads."""
attention_backend: AttnBackend = AttnBackend.auto
"""Attention backend to run. By default we let transformer engine
decide the best backend to run (except in the case of local).
If attention backend is local we use the local pytorch implementation in mcore.
Users can specify exact backend by changing this config. """
softmax_scale: Optional[float] = None
"""Softmax scale for attention scaling."""
softmax_type: Literal['vanilla', 'off-by-one', 'learnable'] = 'vanilla'
"""Applies modified softmax from https://www.evanmiller.org/attention-is-off-by-one.html.
Supports both TE FusedAttention and local unfused attention. Supports both a fixed offset and
a learnable offset."""
num_query_groups: Optional[int] = field(
default=None, metadata={"argparse_meta": {"default": 1}}
)
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size: Optional[int] = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size
if not provided."""
kv_channels: Optional[int] = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
hidden_dropout: float = 0.1
"""Dropout probability for transformer hidden state."""
attention_dropout: float = 0.1
"""Post attention dropout probability."""
fp32_residual_connection: bool = False
"""If true, move residual connections to fp32."""
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residual connection ordering."""
layernorm_epsilon: float = field(
default=1e-5, metadata={"argparse_meta": {"arg_names": ["--norm-epsilon"]}}
)
"""Epsilon value for any LayerNorm/RMSNorm operations."""
layernorm_zero_centered_gamma: bool = field(
default=False, metadata={"argparse_meta": {"arg_names": ["--apply-layernorm-1p"]}}
)
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear: bool = field(
default=True, metadata={"argparse_meta": {"arg_names": ["--disable-bias-linear"]}}
)
"""Include/exclude a bias term in all linear layers (QKV projections, after core attention,
and two in MLP layer)."""
add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""
gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func: Callable[[torch.Tensor], torch.Tensor] = F.gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store: bool = False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is cast back to the original precision before backprop computation."""
glu_linear_offset: float = 0.0
"""Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). Only
used when gated_linear_unit is True"""
activation_func_clamp_value: Optional[float] = None
"""Clamp the output of the linear_fc1 in the activation function. Only used when activation_func
is quick_gelu."""
num_moe_experts: Optional[int] = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""
window_attn_skip_freq: Optional[Union[int, List[int]]] = None
"""Frequency of full attention layers among sliding window attention layers. Accepts either:
- An integer N: Represents a (N-1):1 ratio, one full attention layer after (N-1) SWA layers.
- A list that defines a custom pattern, e.g.: [1,1,1,1,0,0,0,0], where 1 represents SWA. """
normalization: Literal['LayerNorm', 'RMSNorm'] = "LayerNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""
qk_l2_norm: bool = False
"""Whether to apply llama 4-style qk L2 norm."""
qk_clip: bool = False
"""Whether to clip the query and key weights. Needed for Muon MLA Model training."""
qk_clip_alpha: float = 0.5
"""The balancing alpha for qk-clip. Q = Q * (eta ** alpha)"""
qk_clip_threshold: float = 100
"""The balancing threshold for qk-clip. eta = min(threshold / max_attention_logits, 1.0)"""
log_max_attention_logit: bool = False
"""Whether to log the max attention logit across whole model. Decoupled from qk_clip,
defaults to False. Setting qk_clip will automatically log the max logit."""
attention_output_gate: bool = False
"""Whether to apply output gate to the attention layers."""
test_mode: bool = False
"""Whether to run real-time tests."""
calculate_per_token_loss: bool = False
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
multi_latent_attention: bool = False
"""Whether to use multi-latent attention."""
no_rope_freq: Optional[Union[int, List[int]]] = None
"""Controls which layers perform Rotary Position Embedding (RoPE). Accepts either:
An integer N: Creates a pattern where RoPE is skipped every N-1 layers. For example,
no_rope=4 means RoPE is applied for 3 layers, then skipped for 1 layer, repeating this pattern.
A list of integers: Defines a custom pattern where 1 means skip RoPE and 0 means apply RoPE.
For example, [0,1,1,0] means: apply RoPE, skip RoPE, skip RoPE, apply RoPE."""
####################
# attention variant
####################
experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None
"""Type of attention variant to use. Currently support gated_delta_net and dsa."""
####################
# DSA
####################
dsa_indexer_n_heads: Optional[int] = None
"""Number of DSA indexer heads."""
dsa_indexer_head_dim: Optional[int] = None
"""Dimension per DSA indexer head."""
dsa_indexer_topk: Optional[int] = None
"""Number of top-k tokens to select in DSA indexer."""
dsa_indexer_loss_coeff: Optional[float] = None
"""Coefficient for the DSA indexer KL divergence loss. Set to 0 to disable indexer loss."""
dsa_indexer_use_sparse_loss: bool = False
"""Whether to use sparse DSA indexer loss. If True, the indexer loss will be computed using the
top-k indices."""
####################
# linear attention
####################
linear_attention_freq: Optional[Union[int, List[int]]] = None
"""Frequency between LA (linear attention) layers
and SDPA (scaled dot-product attention) layers.
Accepts either:
- An integer N: Represents a (N-1):1 ratio, meaning (N-1) LA layers for every 1 SDPA layer
- A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""
linear_conv_kernel_dim: Optional[int] = 4
"""Conv kernel dimension for the gated delta net."""
linear_key_head_dim: Optional[int] = 128
"""Query and key head dimension for the gated delta net."""
linear_value_head_dim: Optional[int] = 128
"""Value and gate head dimension for the gated delta net."""
linear_num_key_heads: Optional[int] = 16
"""Number of query and key heads for the gated delta net."""
linear_num_value_heads: Optional[int] = 32
"""Number of value and gate heads for the gated delta net."""
####################
# initialization
####################
init_method: Optional[Callable] = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Optional[Callable] = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers).
Note: this does not control vocab readout/unembedding initialization."""
init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
embedding_init_method: Optional[Callable] = None
"""
Method to initialize weights of the embedding layer. If None, will be set as described
in init_method above.
"""
embedding_init_method_std: Optional[float] = None
"""
Standard deviation of the zero mean normal for the default initialization method for the
embedding layer. If None, will be set to init_method_std. Setting this to a value around
1.0 may avoid loss spikes in training. Setting this to any value will also skip applying
weight decay on embedding weights to avoid shrinkage towards zero.
See https://arxiv.org/abs/2312.16903 for more details.
"""
init_model_with_meta_device: bool = False
"""
If True, initializes the model with the meta device. This is helpful for
training of very large models. This feature only works when Megatron FSDP is turned on.
"""
####################
# MuP (Maximal Update Parameterization)
####################
use_mup: bool = False
"""
Enable Maximal Update Parameterization (MuP) for hyperparameter transfer across
model widths. When enabled, learning rates and initialization are scaled according
to the width multiplier to ensure consistent training dynamics.
"""
mup_width_mult: float = 1.0
"""
Width multiplier for MuP scaling, computed as hidden_size / mup_base_hidden_size.
This value is automatically computed in __post_init__ when use_mup is enabled.
"""
mup_base_hidden_size: Optional[int] = None
"""
Base hidden size for MuP width scaling. This is the reference width from which
scaling factors are computed. Defaults to hidden_size if not specified (base model
case where width_mult=1.0). Set this to your base/proxy model's hidden size when
scaling up.
"""
mup_embedding_mult: float = 1.0
"""
Multiplier for embedding layer output. Applied after the embedding lookup.
Default: 1.0 (no scaling).
"""
mup_output_mult: float = 1.0
"""
Multiplier for output logits before softmax. When MuP is enabled and this is left
at 1.0, it is auto-set to 1/mup_width_mult to keep output variance stable across
widths. Override to customize output scaling.
Default: 1.0.
"""
mup_base_head_dim: Optional[float] = None
"""
Base head dimension for MuP attention scaling. When set,
softmax_scale = sqrt(mup_base_head_dim) / (kv_channels ** mup_attn_scale_power).
Set to base model's d_head (e.g., 64) to match standard 1/sqrt(d_head) scaling
at the base width, ensuring non-MuP compatibility for that specific value.
"""
mup_attn_scale_power: float = 1.0
"""
Power for attention scaling: softmax_scale = 1 / (kv_channels ** mup_attn_scale_power).
0.5 = standard attention (1/sqrt(d_head)), 1.0 = MuP attention (1/d_head).
Default: 1.0 (MuP scaling when use_mup is True). Set to 0.5 for standard scaling.
"""
####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
fp16. Also sets `attention_softmax_in_fp32` to True."""
attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
disable_bf16_reduced_precision_matmul: bool = False
"""If True, sets torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction=False to
prevent matmul from using reduced precision accumulation when using BF16."""
####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""
persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""
use_fused_weighted_squared_relu: bool = False
"""If True, uses fused weighted squared relu kernel when using MoE."""
fused_single_qkv_rope: bool = False
"""If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads."""
fused_residual_rmsnorm: bool = False
"""If True, fuses residual connection and RMSNorm backward pass when TE is used."""
####################
# activation recomputation
####################
recompute_granularity: Optional[Literal['full', 'selective']] = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where the submodules set in --recompute-modules are checkpointed.
The default is "core_attn" which is the memory intensive part of attention.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: Optional[Literal['uniform', 'block']] = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: Optional[int] = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations: Optional[bool] = False
"""If True, distribute recomputed activations across the model parallel group."""
recompute_modules: Optional[List[str]] = None
"""The submodules to recompute.
choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe", "shared_experts".
default: ["core_attn"].
"core_attn": recompute the core attention part of the transformer layer.
"moe_act": recompute the MoE MLP activation function.
"layernorm": recompute the input_layernorm and pre_mlp_layernorm.
"mla_up_proj": recompute the MLA up projection and RoPE applying parts.
"mlp": recompute the dense MLP submodule.
"moe": recompute the MoE layer.
"shared_experts": recompute the shared experts in the MoE layer.
"moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing,
"core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing.
"""
####################
# fp8 related
####################
fp8: Optional[Literal['e4m3', 'hybrid']] = field(
default=None, metadata={"argparse_meta": {"arg_names": ["--fp8-format"]}}
)
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_recipe: Optional[Literal['tensorwise', 'delayed', 'mxfp8', 'blockwise', 'custom']] = (
"delayed"
)
"""If set, enables the use of FP8 precision through Transformer Engine. There are 5 predefined
choices (1) 'tensorwise' uses per tensor current scaling recipe, (2) 'delayed'
uses delayed scaling recipe, (3) 'mxfp8' for Blackwell architecture only,
(4) 'blockwise' for blockwise scaling recipe, (5) 'custom' for custom quantization recipe."""
fp8_param: bool = False
"""If set, keep the parameters in fp8 precision to save memory. This option must be used
together with fp8 mode (i.e., TransformerConfig.fp8 is not None). Note that not all parameters
will be converted to fp8; for example, biases will remain unchanged. The parameters affected are
primarily the weights of GEMMs. The specific parameters that will be converted to fp8 are
determined by TE."""
fp8_quantizer_factory: Optional[str] = None
"""Python import path to a callable quantizer factory, e.g., package.module.quantizer_factory.
Required when fp8_recipe is custom."""
fp8_margin: int = 0
"""Margin for the scaling factor computation."""
fp8_interval: int = 1
"""DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
Controls how often the scaling factor is recomputed.
"""
fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo: Literal['most_recent', 'max'] = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation
in higher precision."""
fp8_dot_product_attention: bool = False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention: bool = False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
tp_only_amax_red: bool = False
"""When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain"""
first_last_layers_bf16: bool = False
"""If True, retains first and last N TransformerBlocks in BF16 as opposed to FP8."""
num_layers_at_start_in_bf16: int = 1
"""Number of layers at the start of the model to keep in BF16 precision when
first_last_layers_bf16 is True."""
num_layers_at_end_in_bf16: int = 1
"""Number of layers at the end of the model to keep in BF16 precision when
first_last_layers_bf16 is True."""
use_kitchen: bool = False
"""Use the kitchen extension for transformer quantization."""
use_kitchen_attention: bool = False
"""Use the kitchen extension for attention (instead of TE's attention)."""
kitchen_attention_backend: Literal["sdpa", "fa"] = "sdpa"
"""Which kitchen attention backend to use when use_kitchen_attention=True.
"sdpa" for KitchenDotProductAttention, "fa" for KitchenFlashAttention."""
####################
# fp4 related
####################
fp4: Optional[Literal['e2m1']] = field(
default=None, metadata={"argparse_meta": {"arg_names": ["--fp4-format"]}}
)
"""If set, enables the use of FP4 precision through Transformer Engine. Currently only
accepts 'e2m1'; the 'nvfp4' recipe uses NVFP4BlockScaling (requires TE >= 2.7.0.dev0)."""
fp4_recipe: Optional[Literal['nvfp4', 'custom']] = "nvfp4"
"""If set, enables the use of FP4 precision through Transformer Engine. Currently only
'nvfp4' is supported which uses NVFP4BlockScaling recipe for Blackwell+ architecture."""
fp4_param: bool = field(
default=False, metadata={"argparse_meta": {"arg_names": ["--fp4-param-gather"]}}
)
"""If set, keep the parameters in fp4 precision to save memory. This option must be used
together with fp4 mode (i.e., TransformerConfig.fp4 is not None). Note that not all parameters
will be converted to fp4; for example, biases will remain unchanged."""
fp4_quantizer_factory: Optional[str] = None
"""Python import path to a callable quantizer factory, e.g., package.module.quantizer_factory.
Required when fp4_recipe is custom."""
####################
# MoE related
####################
moe_shared_expert_intermediate_size: Optional[int] = None
"""Shared expert total ffn hidden size.
It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if
there are multiple shared experts.
None means no shared expert.
By default, the shared experts execute before the router. However, when
moe_shared_expert_overlap or overlap_moe_expert_parallel_comm is set,
the shared experts execute after the router, before the routed experts.
This makes the gradients from the router and the shared experts added in
different orders to the hidden_states, causing minor numerical differences
in the hidden_states gradient."""
moe_shared_expert_gate: bool = False
"""Enable gate for shared expert. Only effective when
moe-shared-expert-intermediate-size is set."""
moe_shared_expert_overlap: bool = False
"""Enable overlapping between shared expert computations and dispatcher communications.
Without this, the shared experts execute before the router.
Only effective when moe-shared-expert-intermediate-size is set.
"""
moe_layer_freq: Union[int, List[int]] = 1
"""Frequency between MoE layers and Dense layers. Accepts either:
- An integer N: Represents a 1:(N-1) ratio, i.e. one expert layer for every N-1 dense layers.
- A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""
moe_ffn_hidden_size: Optional[int] = None
"""MoE Feed-Forward Network hidden size. If not specified, defaults to the ffn_hidden_size."""
moe_router_load_balancing_type: Union[str, List[str]] = "aux_loss"
"""The load balancing strategy for the router.
Options:
- "aux_loss": Load balancing loss used in GShard and SwitchTransformer, calculated at
micro-batch level.
- "seq_aux_loss": Load balancing loss used in DeepSeekV2 and DeepSeekV3, computes loss
for each individual sample.
- "global_aux_loss": Load balancing loss calculated at global batch level.
- "sinkhorn": Balancing algorithm used in S-BASE.
- "none": No load balancing.
A list of strings can be provided to combine multiple aux-loss load balancing types.
The default is "aux_loss".
"""
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_enable_routing_replay: bool = False
"""If True, enable the routing replay feature for MoE layers."""
moe_router_topk_limited_devices: Optional[int] = None
"""Number of EP ranks to consider for each token in group-limited routing,
DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
"""
moe_router_padding_for_quantization: Optional[bool] = False
"""Whether to pad the routing_map to make sure the number of tokens each expert receives
is a multiple of 16/32 for quantized precision (e.g., FP8, FP4). This can remove the explicit
padding in the GroupedMLP layer."""
moe_router_padding_for_fp8: Optional[bool] = False
"""[Compatibility alias for moe_router_padding_for_quantization]
Enabling this will also enable moe_router_padding_for_quantization."""
moe_router_num_groups: Optional[int] = None
"""Number of groups to divide experts into for group-limited routing.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on sum of
top-('moe_router_topk'/'moe_router_group_topk') routing scores within each group
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
"""
moe_router_group_topk: Optional[int] = None
"""Number of selected groups for group-limited routing."""
moe_router_pre_softmax: bool = False
"""Enable pre-softmax(pre-sigmoid) routing for MoE, which means softmax is before the
top-k selection.
By default, softmax is done after top-k."""
moe_router_topk_scaling_factor: Optional[float] = None
"""Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
is enabled. Defaults to None, which means no scaling."""
moe_router_score_function: Literal['softmax', 'sigmoid'] = "softmax"
"""Score function for MoE routing. Can be "softmax" or "sigmoid"."""
moe_router_dtype: Optional[Literal['fp32', 'fp64']] = None
"""Data type for routing and expert output weighted averaging. Using fp32 or fp64 can
improve stability especially when the number of experts is large (e.g. finegrained-moe).
None means no changes for dtype."""
moe_router_enable_expert_bias: bool = False
"""TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy.
The routing decision is based on the sum of the routing scores and the expert bias.
See https://arxiv.org/abs/2408.15664 for details."""
moe_router_bias_update_rate: float = 1e-3
"""The expert bias is updated based on the number of assigned tokens to each expert
in a global batch, where the bias is increased for the experts with fewer assigned tokens
and decreased for the experts with more assigned tokens.
The default value 1e-3 is the same as that used in DeepSeekV3."""
moe_router_force_load_balancing: bool = False
"""[Experimental] Force load balancing with random logits for MoE router, supports naive topk
and group-limited topk. This is an experimental feature and only for benchmark."""
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""
moe_aux_loss_coeff: Union[float, List[float]] = 0.0
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended.
If a list of load balancing types is provided for `moe_router_load_balancing_type`,
a corresponding list of coefficients should be provided here."""
moe_z_loss_coeff: Optional[float] = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: Optional[float] = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'flex'] = "allgather"
"""The type of token dispatcher to use. The default is 'allgather'.
Options are 'allgather', 'alltoall' and 'flex'."""
moe_enable_deepep: bool = False
"""[Experimental] Enable DeepEP for efficient token dispatching and combine in MoE models."""
moe_flex_dispatcher_backend: Literal['deepep', 'hybridep'] = "deepep"
"""[Experimental] The backend to use for flex token dispatcher. The default is "deepep".
Options are "deepep" and "hybridep". Currently only "hybridep" backend supports
the MNNVL case."""
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
moe_expert_capacity_factor: Optional[float] = None
"""moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token
will be dropped. The default is None."""
moe_pad_expert_input_to_capacity: bool = False
"""moe_pad_expert_input_to_capacity (bool): If True, pads the input for each expert to match
the expert capacity length, effective only after the moe_expert_capacity_factor is set. The
default setting is False."""
moe_pad_experts_for_cuda_graph_inference: bool = False
"""moe_pad_experts_for_cuda_graph_inference (bool): If True, the router will switch to dropping
and padding during decode time which does not have a D2H sync. The capacity factor is set to the
max that an expert could see during inference so no tokens are actually dropped. The default
setting is False."""
moe_token_drop_policy: Literal['probs', 'position'] = "probs"
"""The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with
the lowest probabilities will be dropped. If "position", tokens at the end of each batch will
be dropped.
"""
moe_layer_recompute: bool = False
"""Memory optimization: checkpointing moe_layer to save activation memory."""
moe_permute_fusion: bool = False
"""Fuse token rearrangement ops during token dispatching."""
moe_router_fusion: bool = False
"""Enable fusion for MoE TopK routing and aux-loss computation. This is only
supported in TransformerEngine 2.7.0 and above.
"""
moe_apply_probs_on_input: bool = False
"""Apply probs on input of experts instead of applying after activation and glu."""
moe_latent_size: Optional[int] = None
"""Latent projection dimension for MoE. If None, MoE latent projections are not used."""
moe_deepep_num_sms: int = 20
"""Number of SMs to use for DeepEP."""
moe_hybridep_num_sms: int = 16
"""Number of SMs to use for HybridEP. In pure NVL scenarios,
16 SMs can generally achieve good bandwidth."""
##################
# Context Parallel
##################
cp_comm_type: Optional[Union[str, List[str]]] = None
"""Inter-gpu communication type for context parallelism.
str: all layers share same communication type.
List[str]: each layer has its separate communication type.
cp_comm_type of each layer can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be
overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention. The all-gather is not
async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get
full sequence of QKV.
"a2a+p2p": A hierarchical implementation of context parallelism to attention.
It uses A2A communications in low-level CP groups (e.g., via NVLink),
and P2P communications in high-level CP groups (e.g., via IBLink).
"""
##################
# Cuda Graphs
##################
enable_cuda_graph: bool = False
"""DEPRECATED and replaced by cuda_graph_impl.
When set to true, either partial CUDA graph (1/many CUDA graph per layer) or full iteration
CUDA graph (1 CUDA graph for whole iteration excluding optimizer) is enabled. --cuda-graph-scope
determines the scope of graph capture."""
cuda_graph_use_single_mempool: bool = False
"""[For `local` implementation only] When set to true, cudagraphs will be captured inside a
single mempool, in which all cudagraphs may only be used once per step. If false, cudagraphs may
be reused across microbatches. Enabling may reduce cudagraph memory overheads due to memory
fragmentation, however may greatly increase the number of cudagraphs created when the number of
microbatches is high."""
cuda_graph_retain_backward_graph: bool = False
"""When set to true, cudagraph backward passes will be graph captured with
'retain_graph=True'. This may enable cudagraphs for certain modules that are not completely
cudagraph safe. For more details, see:
https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html."""
cuda_graph_warmup_steps: int = 3
"""Number of warmup steps for CUDA graphs"""
external_cuda_graph: bool = False
"""DEPRECATED and replaced by cuda_graph_impl.
When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""
cuda_graph_impl: Literal['none', 'local', 'transformer_engine'] = "none"
"""Determines the CUDA graph capture implementation.
"none": no CUDA graph.
"local": capture the CUDA graph using MCore local implementation. Either partial CUDA graph
(1/many CUDA graph per layer) or full iteration CUDA graph (1 CUDA graph for whole iteration
excluding optimizer) is enabled.
"transformer_engine": capture the CUDA graph using TE make_graphed_callables()."""
cuda_graph_scope: Union[str, CudaGraphScope, List[str], List[CudaGraphScope]] = "full"
"""Determines the CUDA graphs capturing scope.
When cuda_graph_impl is set to "transformer_engine", valid values are "attn", "mlp", "moe",
"moe_router", "moe_preprocess", "mamba". "full" or an empty list means the full layer. "full"
is actually deprecated, but for backward compatibility, we still use "full" as the default
value. It will be transformed to an empty list in __post_init__.
When cuda_graph_impl is set to "local", "full_iteration" can be specified as cuda_graph_scope
to enable whole iteration CUDA graph. All other values enable layerwise CUDA graph."""
####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""
disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""
config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""
flash_decode: bool = False
""" Use the optimized flash decoding kernel during inference. """
batch_invariant_mode: bool = False
"""If true, uses batch-invariant kernels that provide deterministic forward execution regardless
of batch size. This ensures bitwise identical results when the same inputs are processed
in different batch configurations. This will significantly affect speed of
training and inference as the kernels are not fully optimized.
Defaults to False."""
use_te_activation_func: bool = False
"""Whether to use ffn activation functions implemented by TransformerEngine"""
use_te_rng_tracker: bool = False
""" Whether to use the TE or MCore version of the RNG tracker. """
inference_rng_tracker: bool = False
""" Whether we should instantiate a separate RNG tracker for inference. """
inference_sampling_seed: int = 42
""" Random seed to use for sampling during inference. """
symmetric_ar_type: Optional[Literal['two_shot', "one_shot", "multimem_all_reduce"]] = None
"""What type of symmetric all reduce to use. The default is None,
which means symmetric memory is not used.
"""
nccl_all_reduce_for_prefill: bool = False
"""If True, use NCCL all-reduce kernels when symmetric all-reduce is enabled."""
use_inference_optimized_layers: bool = False
"""If True, use inference optimized transformer layers during inference."""
inference_fuse_tp_communication: bool = False
""" If true, uses a fused reduce-scatter-residual-norm-allgather kernel during inference. """
inference_disable_triton_nvls_kernels: bool = False
""" If true, disables the use of Triton NVLS kernels during inference. """
inference_disable_torch_grouped_mm: bool = False
""" If true, disables torch._grouped_mm in InferenceGroupedMLP,
falling back to TE GroupedGEMM. """
mrope_section: Optional[List[int]] = None
""" Multimodal rope section is for channel dimension of temporal, height and width
in rope calculation. """
is_hybrid_model: bool = False
""" Indicates whether this is a hybrid model. """
mamba_state_dim: int = 128
"""The dimensionality of the state representation in Mamba layers."""
mamba_head_dim: int = 64
"""The dimensionality of the heads in the Mamba layers."""
mamba_num_groups: int = 8
"""The number of groups used in Mamba layers."""
mamba_num_heads: Optional[int] = None
"""The number of heads used in Mamba layers.
If None, the number of heads will be hidden_size * expand // mamba_head_dim."""
use_mamba_mem_eff_path: bool = field(
default=True, metadata={"argparse_meta": {"arg_names": ["--disable-mamba-mem-eff-path"]}}
)
"""Controls usage of the memory efficient path for Mamba layers."""
mlp_chunks_for_prefill: int = 1
"""The number of chunks along the sequence dimension to use for MLP computation
during prefill."""
heterogeneous_block_specs: bool = False
"""Whether to use heterogeneous block specs (nemotron-nas architecture)."""
hetereogenous_dist_checkpoint: bool = False
"""Whether to use heterogeneous layers in distributed checkpoint."""
####################
# Quantization
####################
quant_recipe: Optional[RecipeConfig] = None
"""Configuration of any per-module quantization settings to be applied to the model"""
transformer_impl: Literal['local', 'transformer_engine', 'inference_optimized'] = (
"transformer_engine"
)
"""Transformer implementation to use.
Options are 'transformer_engine' for Transformer Engine, 'local' for MCore, and
'inference_optimized'."""
#####################################
# Fine-grained Activation Offloading
#####################################
fine_grained_activation_offloading: bool = False
"""If True, offload the input of the specified modules to the CPU.
Fine-grained activation offloading is a module-level offloading method
instead of a layer-level offloading method like cpu_offloading."""
offload_modules: Optional[list[str]] = field(default_factory=list)
"""The submodules whose inputs are offloaded.
choices: "attn_norm", "qkv_linear", "core_attn", "attn_proj",
"mlp_norm", "expert_fc1", "moe_act".
"attn_norm": offload the input of the normalization in the attention part.
"qkv_linear": offload the input of the qkv linear part.
"core_attn": offload the input of the core attention part.
"attn_proj": offload the input of the attn linear projection part.
"mlp_norm": offload the input of the normalization in the mlp part.
"expert_fc1": offload the input of the expert fc1 part.
"moe_act": offload the input of the moe act part.
"""
min_offloaded_tensor_size: int = 1024 * 1024
"""The minimum size of the tensor to be offloaded."""
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
"""
super().__post_init__()
# When fp32 residual connections are enabled, pipeline parallel communication must