forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharguments.py
More file actions
3192 lines (2866 loc) · 176 KB
/
arguments.py
File metadata and controls
3192 lines (2866 loc) · 176 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Megatron arguments."""
import argparse
import dataclasses
import json
import os
from pathlib import Path
import re
import types
import torch
import torch.nn.functional as F
from packaging.version import Version as PkgVersion
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.rerun_state_machine import RerunStateMachine
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout
from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
from megatron.core.transformer.heterogeneous.heterogeneous_config import (
HeterogeneousTransformerConfig,
MLPConfig,
)
from megatron.core.utils import (
get_torch_version,
is_flashinfer_min_version,
is_te_min_version,
is_torch_min_version,
)
from megatron.core.activations import squared_relu
from megatron.core.fusions.fused_bias_geglu import quick_gelu
from megatron.training.utils import (
get_device_arch_version,
update_use_dist_ckpt,
print_rank_0,
warn_rank_0,
)
from megatron.core.msc_utils import MultiStorageClientFeature
from megatron.core.quantization.utils import (
kitchen_quantization_recipe_config,
load_quantization_recipe,
)
from megatron.training.argument_utils import ArgumentGroupFactory
def add_megatron_arguments(parser: argparse.ArgumentParser):
    """Add Megatron-LM arguments to the given parser.

    Registers every Megatron argument group (network size, regularization,
    training, RL, parallelism, checkpointing, MoE/MLA, logging, inference,
    Transformer Engine, quantization, SFT, etc.) on ``parser``.

    Args:
        parser: The argparse parser to extend.

    Returns:
        argparse.ArgumentParser: The same parser with all Megatron argument
        groups added.
    """
    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_rl_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_tokenizer_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vision_args(parser)
    parser = _add_moe_args(parser)
    parser = _add_mla_args(parser)
    parser = _add_experimental_attention_variant_args(parser)
    parser = _add_heterogeneous_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_straggler_detector_args(parser)
    parser = _add_workload_inspector_server_args(parser)
    parser = _add_inference_args(parser)
    parser = _add_transformer_engine_args(parser)
    parser = _add_experimental_args(parser)
    parser = _add_one_logger_args(parser)
    parser = _add_inprocess_restart_args(parser)
    parser = _add_ft_package_args(parser)
    parser = _add_rerun_machine_args(parser)
    parser = _add_msc_args(parser)
    parser = _add_kitchen_quantization_arguments(parser)
    parser = _add_sft_args(parser)
    return parser
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the Megatron argument parser and parse the command line.

    Args:
        extra_args_provider: Optional callable that receives the parser and
            may register additional, application-specific arguments.
        ignore_unknown_args: When True, unrecognized arguments are silently
            dropped instead of raising an error.

    Returns:
        The parsed argument namespace, augmented with ``rank`` and
        ``world_size`` taken from the environment.
    """
    parser = argparse.ArgumentParser(
        description='Megatron-LM Arguments', allow_abbrev=False
    )
    parser = add_megatron_arguments(parser)

    # Let the caller register any application-specific arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse, optionally tolerating unknown flags.
    args = (
        parser.parse_known_args()[0]
        if ignore_unknown_args
        else parser.parse_args()
    )

    # Experimental: a YAML config file replaces the CLI-derived namespace.
    if args.yaml_cfg is not None:
        from .yaml_arguments import load_yaml
        assert args.yaml_cfg and not args.use_legacy_models, \
            "Yaml config is not supported with legacy models."
        args = load_yaml(args.yaml_cfg)

    # Distributed launch info always comes from the environment.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    # Honor the flag that disables the Multi-Storage Client feature.
    if not args.enable_msc:
        MultiStorageClientFeature.disable()
        assert MultiStorageClientFeature.is_enabled() is False
        warn_rank_0('The MSC feature is disabled.')

    return args
def validate_model_config_args_from_heterogeneous_config(args):
    """Validate model config arguments against a heterogeneous layer config.

    The heterogeneous config may be supplied either as a path to a JSON file
    (``heterogeneous_layers_config_path``) or as an already-encoded JSON
    string (``heterogeneous_layers_config_encoded_json``). When neither is
    given, this is a no-op. Otherwise the config is loaded and the provided
    args are checked against the architecture it implies (SiLU/SwiGLU,
    RMSNorm, grouped-query attention, RoPE, layer/hidden/head counts).

    Args:
        args: Parsed argument namespace; read and possibly updated in place
            (the encoded-JSON attribute is filled in from the file path).

    Raises:
        ValueError: If any provided argument disagrees with the value
            mandated by the heterogeneous config.
    """
    path = args.heterogeneous_layers_config_path
    encoded = args.heterogeneous_layers_config_encoded_json
    if path is None and encoded is None:
        return

    # Materialize the encoded JSON from the file path if needed.
    if encoded is None:
        encoded = Path(path).read_text()
        args.heterogeneous_layers_config_encoded_json = encoded

    hf_cfg = types.SimpleNamespace(**json.loads(encoded))

    assert hf_cfg.hidden_act == "silu", (
        f"hidden_act in heterogeneous config is {hf_cfg.hidden_act}, should be silu"
    )

    # Collect KV-head grouping from attention blocks; layers without
    # attention (n_heads_in_group is None) are skipped.
    n_kv_heads_in_group = [
        block["attention"]["n_heads_in_group"]
        for block in hf_cfg.block_configs
        if block["attention"]["n_heads_in_group"] is not None
    ]
    assert all(n == n_kv_heads_in_group[0] for n in n_kv_heads_in_group), \
        "num query head must be consistent across all layers"

    expected = {
        "swiglu": True,
        "normalization": "RMSNorm",
        "group_query_attention": True,
        "position_embedding_type": "rope",
        "rotary_percent": 1.0,
        "use_rope_scaling": True,
        "use_rotary_position_embeddings": True,
        "num_layers": hf_cfg.num_hidden_layers,
        "hidden_size": hf_cfg.hidden_size,
        "num_attention_heads": hf_cfg.num_attention_heads,
        "untie_embeddings_and_output_weights": not hf_cfg.tie_word_embeddings,
        "rotary_base": hf_cfg.rope_theta,
        "rope_scaling_factor": hf_cfg.rope_scaling["factor"],
        "num_query_groups": hf_cfg.num_attention_heads // n_kv_heads_in_group[0],
    }

    # Gather every mismatch before failing so the error lists them all.
    mismatches = {
        key: (getattr(args, key, None), want)
        for key, want in expected.items()
        if getattr(args, key, None) != want
    }
    if mismatches:
        details = ', '.join(
            f"{key}: {got} (provided) != {want} (expected)"
            for key, (got, want) in mismatches.items()
        )
        raise ValueError(
            f"Arguments differ from heterogeneous config: {details}"
        )
def _eval_pattern(pattern):
""" Validate and evaluate a string containing a Python list expression """
assert isinstance(pattern, str)
# validate input, only allow comma, digits, [, ], (, ), +, and *
if bool(re.compile(r'[^,\d\[\]\(\)\+\*]').search(pattern)):
raise ValueError(f"Invalid pattern: {pattern}")
return eval(pattern)
def no_rope_freq_type(x):
    """Parse the argument controlling which layers skip Rotary Position Embedding.

    Accepts:
    - ``None`` or an integer N: a 1:N ratio, i.e. RoPE is skipped every
      N-1 layers.
    - A string ``"N"``: same as the integer form.
    - A string holding a Python list expression describing a custom
      pattern, e.g. ``"([0]*3+[1]*1)*3"`` -> ``[0,0,0,1,0,0,0,1,0,0,0,1]``,
      where 1 means RoPE is skipped on that layer. The pattern length must
      equal the total number of transformer layers.

    Examples:
        "([1]+[0]*23)": only the first layer of a 24-layer network skips RoPE.
        "([0]*3+[1]*1)*2": every 4th layer skips RoPE, repeated twice.
    """
    if x is None or isinstance(x, int):
        return x
    assert isinstance(x, str)
    # A '[' signals an explicit per-layer pattern; otherwise it's an int.
    return _eval_pattern(x) if '[' in x else int(x)
def moe_freq_type(x):
    """Parse the MoE/Dense layer frequency argument.

    Accepts:
    - An integer N: a 1:N ratio, i.e. one expert layer for every N-1
      dense layers.
    - A string ``"N"``: same as the integer form.
    - A string holding a Python list expression describing a custom
      pattern, e.g. ``"([1]*3+[0]*1)*3"`` -> ``[1,1,1,0,1,1,1,0,1,1,1,0]``,
      where 1 marks an expert layer and 0 a dense layer. The pattern length
      must equal the total number of transformer layers.

    Examples:
        "([0]+[1]*23)": 1 dense layer followed by 23 expert layers.
        "([1]*3+[0]*2)*2": three expert layers then two dense, repeated twice.
    """
    if isinstance(x, int):
        return x
    assert isinstance(x, str)
    # A '[' signals an explicit per-layer pattern; otherwise it's an int.
    return _eval_pattern(x) if '[' in x else int(x)
def la_freq_type(x):
    """Parse the LA/SDPA layer frequency argument.

    Controls the ratio of LA (linear attention) layers to SDPA (scaled
    dot-product attention) layers. Accepts:
    - ``None`` or an integer N: an (N-1):N ratio, i.e. N-1 LA layers for
      every 1 SDPA layer.
    - A string ``"N"``: same as the integer form.
    - A string holding a Python list expression describing a custom
      pattern, e.g. ``"([1]*3+[0]*1)*3"`` -> ``[1,1,1,0,1,1,1,0,1,1,1,0]``,
      where 1 marks an LA layer and 0 an SDPA layer. The pattern length
      must equal the total number of transformer layers.

    Examples:
        "([0]+[1]*23)": 1 SDPA layer followed by 23 LA layers.
        "([1]*3+[0]*2)*2": three LA layers then two SDPA, repeated twice.
    """
    if x is None or isinstance(x, int):
        return x
    assert isinstance(x, str)
    # A '[' signals an explicit per-layer pattern; otherwise it's an int.
    return _eval_pattern(x) if '[' in x else int(x)
def tuple_type(x):
    """Convert a comma-separated string (optionally parenthesized) to a tuple of ints.

    ``None`` and existing tuples pass through unchanged.

    Examples:
        "1,2,3" -> (1, 2, 3)
        "(1,2,3)" -> (1, 2, 3)
    """
    if x is None or isinstance(x, tuple):
        return x
    assert isinstance(x, str)
    parts = x.strip('()').split(',')
    return tuple(map(int, parts))
def validate_args(args, defaults={}):
# Temporary
assert args.non_persistent_ckpt_type in ['global', 'local', None], \
'Currently only global and local checkpoints are supported'
if args.non_persistent_ckpt_type == 'local':
try:
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
LocalCheckpointManager
except ModuleNotFoundError as e:
raise RuntimeError('nvidia_resiliency_ext is required for local checkpointing') from e
# validate model config args from heterogeneous config (if provided).
validate_model_config_args_from_heterogeneous_config(args)
# Set args.use_dist_ckpt from args.ckpt_format.
if args.use_legacy_models:
assert args.ckpt_format == "torch", \
"legacy model format only supports the 'torch' checkpoint format."
update_use_dist_ckpt(args)
total_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size
# Total model size.
assert args.world_size % total_model_size == 0, (
f"world size ({args.world_size}) is not divisible by total_model_size ({total_model_size=})"
)
if args.attention_backend == AttnBackend.local:
assert args.spec[0] == 'local' , '--attention-backend local is only supported with --spec local'
# Pipeline model parallel size.
args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size
total_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size
args.data_parallel_size = args.world_size // total_model_size
if args.perform_rl_step:
# ----------------------------------------------------------------
# CUDA graphs
#
# --cuda-graph-impl controls whether CUDA graphs are built.
# The sweep of various inference CUDA graphs is built inside inference, not the RL loop.
# Both training and inference CUDA graphs are gated by this flag.
#
# --rl-training-cuda-graphs controls whether CUDA graphs are used during training.
# Toggling CUDA graphs on and off is done inside the RL loop.
#
# --rl-persist-cuda-graphs controls whether CUDA graphs are built once, or repeatedly.
# When this flag is True, inference requires static memory pointers for the KV cache.
# When this flag is False, inference is in charge of deleting/rebuilding CUDA graphs.
#
# KV cache management (--rl-kv-cache-management-mode)
#
# Inference initializes the KV cache, inside either a normal memory pool, UVM, or TMS.
#
# On suspend (inference -> training):
# "persist" — no-op; KV cache stays on GPU.
# "offload" — KV cache is offloaded to CPU.
# "recompute" — KV cache is deleted entirely.
#
# On resume (training → inference):
# "persist" — no-op; KV cache is already on GPU.
# "offload" — KV cache is restored from CPU.
# "recompute" — KV cache is recomputed from scratch.
# ----------------------------------------------------------------
# Persisting CGs only makes sense if we build any CGs.
assert not args.rl_persist_cuda_graphs or args.cuda_graph_impl != "none", (
"--rl-persist-cuda-graphs is set but no CUDA graphs are being built."
)
# Training CGs only makes sense if we build any CGs.
assert not args.rl_training_cuda_graphs or args.cuda_graph_impl != "none", (
"--rl-training-cuda-graphs is set but no CUDA graphs are being built."
)
# If CUDA graphs persist and KV cache memory address is not static, we need
# either UVM or torch_memory_saver to maintain memory address stability for CGs.
if args.rl_persist_cuda_graphs and args.rl_kv_cache_management_mode != "persist":
try:
from torch_memory_saver import torch_memory_saver
except ImportError:
assert args.inference_dynamic_batching_unified_memory_level > 0, (
"Persisting CUDA graphs requires static KV cache memory. Use "
"--rl-kv-cache-management-mode=persist, UVM, or install torch_memory_saver."
)
# Offload mode requires CG persistence: CG recapture runs dummy forward
# passes that corrupt the preserved KV data.
assert (
(not args.rl_kv_cache_management_mode == "offload") or (args.rl_persist_cuda_graphs)
), "--rl-kv-cache-management-mode=offload requires --rl-persist-cuda-graphs"
# There's no need to manually offload the KV cache with UVM.
assert not (
args.inference_dynamic_batching_unified_memory_level > 0
and args.rl_kv_cache_management_mode == "offload"
), "--rl-kv-cache-management-mode=offload is incompatible with UVM"
# We currently cannot recapture CGs in offload mode.
assert not(
not args.rl_persist_cuda_graphs and args.rl_kv_cache_management_mode == "offload"
), "Cannot recapture CUDA graphs while offloading KV cache."
# Validate inference model offloading - requires either UVM or torch_memory_saver
if args.rl_offload_inference_model_weights_when_idle:
if args.rl_inference_model_unified_memory_level != 1:
# Not using UVM, so we need torch_memory_saver
try:
from torch_memory_saver import torch_memory_saver
except ImportError:
raise AssertionError(
"To use --rl-offload-inference-model-weights-when-idle without UVM "
"(--rl-inference-model-unified-memory-level=1), `torch_memory_saver` must be "
"installed. See https://github.com/fzyzcjy/torch_memory_saver."
)
args.grpo_samples_per_iteration = args.grpo_prompts_per_step * args.grpo_group_size
num_generated_samples_per_inference_iteration = (
args.grpo_samples_per_iteration * args.grpo_iterations)
# Ensure that the number of prompts we collect is a multiple of the global batch size.
# TODO: Make this account for batch size rampup?
assert num_generated_samples_per_inference_iteration % args.global_batch_size == 0, \
f"grpo_group_size * grpo_prompts_per_step * grpo_iterations should be divisible by global_batch_size"
# For now only exit/checkpoint on iterations where we generate data. We don't currently
# have a way to checkpoint the generated data.
num_training_iterations_per_inference_iteration = (
num_generated_samples_per_inference_iteration // args.global_batch_size)
if args.exit_interval is not None:
assert args.exit_interval % num_training_iterations_per_inference_iteration == 0, \
f"exit_interval should be divisible by number of global batches per inference iteration."
if args.save_interval is not None:
assert args.save_interval % num_training_iterations_per_inference_iteration == 0, \
f"save_interval should be divisible by number of global batches per inference iteration."
if args.rl_use_sequence_packing:
assert args.micro_batch_size == 1, \
"micro_batch_size must be 1 when using sequence packing. To increase compute per micro batch increase the sequence length."
print_rank_0('using world size: {}, data-parallel size: {}, '
'context-parallel size: {}, '
'hierarchical context-parallel sizes: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {}'.format(
args.world_size, args.data_parallel_size,
args.context_parallel_size,
args.hierarchical_context_parallel_sizes,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size))
# Checks.
if args.hierarchical_context_parallel_sizes:
from numpy import prod
assert args.context_parallel_size == prod(args.hierarchical_context_parallel_sizes)
if "a2a+p2p" in args.cp_comm_type:
assert args.hierarchical_context_parallel_sizes is not None, \
"--hierarchical-context-parallel-sizes must be set when a2a+p2p is used in cp comm"
if args.expert_tensor_parallel_size is None:
args.expert_tensor_parallel_size = args.tensor_model_parallel_size
# Deprecated arguments.
assert args.batch_size is None, '--batch-size argument is no longer ' \
'valid, use --micro-batch-size instead'
del args.batch_size
assert args.warmup is None, '--warmup argument is no longer valid, use ' \
'--lr-warmup-fraction instead'
del args.warmup
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
print_rank_0('--checkpoint-activations is no longer valid, use --recompute-activations, '
'or, for more control, --recompute-granularity and --recompute-method.')
exit()
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
del args.recompute_activations
if args.enable_cuda_graph or args.external_cuda_graph:
assert (
args.cuda_graph_impl == "none"
), "Do not use --enable-cuda-graph or --external-cuda-graph with --cuda-graph-impl."
assert (
not args.enable_cuda_graph or not args.external_cuda_graph
), "--enable-cuda-graph and --external-cuda-graph cannot be enabled at the same time."
if args.enable_cuda_graph:
print_rank_0(
'--enable-cuda-graph is deprecated, use --cuda-graph-impl=local instead.', args.rank
)
args.cuda_graph_impl = "local"
del args.enable_cuda_graph
if args.external_cuda_graph:
print_rank_0(
'--external-cuda-graph is deprecated, use --cuda-graph-impl=transformer_engine instead.',
args.rank,
)
args.cuda_graph_impl = "transformer_engine"
del args.external_cuda_graph
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key, None) is not None:
warn_rank_0('Overriding default arguments for {key}:{v} '
'with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)))
else:
setattr(args, key, defaults[key])
if args.data_path is not None and args.split is None:
legacy_default_split_value = '969, 30, 1'
warn_rank_0('Please specify --split when using --data-path. Using legacy default value '
f'of "{legacy_default_split_value}"')
args.split = legacy_default_split_value
use_data_path = (args.data_path is not None) or (args.data_args_path is not None)
if use_data_path:
# Exactly one of the two has to be None if we use it.
assert (args.data_path is None) or (args.data_args_path is None)
use_per_split_data_path = any(
elt is not None
for elt in [args.train_data_path, args.valid_data_path, args.test_data_path]) or \
args.per_split_data_args_path is not None
if use_per_split_data_path:
# Exactly one of the two has to be None if we use it.
assert any(elt is not None
for elt in [args.train_data_path, args.valid_data_path, args.test_data_path]) is False or \
args.per_split_data_args_path is None
if args.phase_transition_iterations:
args.phase_transition_iterations = sorted(
int(x.strip()) for x in args.phase_transition_iterations.split(",")
)
assert args.rampup_batch_size is None, "multi-phase training does not support batch size ramp-up"
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
print_rank_0('setting global batch size to {}'.format(args.global_batch_size))
assert args.global_batch_size > 0
    # === Hybrid layer pattern: deprecation handling and validation ===
    # Backward compat: --hybrid-override-pattern is deprecated in favor of --hybrid-layer-pattern
    used_hybrid_override_pattern = False
    if args.hybrid_override_pattern is not None:
        assert args.hybrid_layer_pattern is None, (
            '--hybrid-override-pattern and --hybrid-layer-pattern cannot both be specified. '
            '--hybrid-override-pattern is deprecated; use --hybrid-layer-pattern instead.'
        )
        warn_rank_0(
            "--hybrid-override-pattern is deprecated. Use --hybrid-layer-pattern instead.",
            args.rank,
        )
        # Alias the deprecated flag onto the new one so downstream code only has
        # to consult args.hybrid_layer_pattern.
        args.hybrid_layer_pattern = args.hybrid_override_pattern
        used_hybrid_override_pattern = True
    if args.mtp_hybrid_override_pattern is not None:
        warn_rank_0(
            "--mtp-hybrid-override-pattern is deprecated. "
            "For new hybrid models with MTP, use unified --hybrid-layer-pattern instead. "
            "Example: 'M*M*/MM/MM' means main='M*M*', MTP pattern='MM' with 2 depths. "
            "This argument is kept only for loading old checkpoints.",
            args.rank,
        )
    # Function-scope import of the hybrid-layer helpers (kept local, not at module top).
    from megatron.core.ssm.mamba_hybrid_layer_allocation import (
        Symbols, parse_hybrid_pattern, get_hybrid_total_layer_count,
        get_hybrid_total_pipeline_segment_count,
    )
    sep = Symbols.MTP_SEPARATOR
    # Backward compat: convert legacy mtp_hybrid_override_pattern to unified format
    # only when the unified pattern does not already carry an MTP section.
    if (
        args.mtp_hybrid_override_pattern is not None
        and args.mtp_num_layers is not None
        and args.mtp_num_layers > 0
        and (args.hybrid_layer_pattern is None or sep not in args.hybrid_layer_pattern)
    ):
        # Build 'main<sep>mtp<sep>mtp...' with one MTP segment per MTP layer.
        main_pattern = args.hybrid_layer_pattern or ''
        mtp_pattern = args.mtp_hybrid_override_pattern
        args.hybrid_layer_pattern = main_pattern + sep + sep.join([mtp_pattern] * args.mtp_num_layers)
        args.mtp_hybrid_override_pattern = None
        print_rank_0(f"Converted legacy MTP pattern to unified: {args.hybrid_layer_pattern}")
    if args.hybrid_layer_pattern is not None:
        # Derive num_layers from pattern; hybrid_layer_pattern always overrides --num-layers when
        # both are present (e.g. when loading from checkpoint with --use-checkpoint-args).
        num_layers_in_pattern = get_hybrid_total_layer_count(args.hybrid_layer_pattern)
        if args.num_layers is not None and args.num_layers != num_layers_in_pattern:
            warn_rank_0(
                f'--hybrid-layer-pattern is set; ignoring --num-layers ({args.num_layers}) and '
                f'using the layer count derived from the pattern ({num_layers_in_pattern}).',
                args.rank,
            )
        args.num_layers = num_layers_in_pattern
        # first/last pipeline num layers are incompatible with pipe-separated patterns
        # (the pipe separators already define the pipeline layout explicitly), but are
        # allowed for pipe-free patterns where they control uneven PP splitting.
        # Only the main (pre-MTP-separator) portion of the pattern is inspected for pipes.
        has_pipes = Symbols.PIPE in args.hybrid_layer_pattern.split(sep)[0]
        if has_pipes:
            assert args.decoder_first_pipeline_num_layers is None, (
                'If --hybrid-layer-pattern contains pipe separators, '
                '--decoder-first-pipeline-num-layers should not be specified '
                'as the pipeline layout is explicitly defined.'
            )
            assert args.decoder_last_pipeline_num_layers is None, (
                'If --hybrid-layer-pattern contains pipe separators, '
                '--decoder-last-pipeline-num-layers should not be specified '
                'as the pipeline layout is explicitly defined.'
            )
        # The following options all describe the pipeline layout through other means
        # and therefore conflict with a hybrid layer pattern (pipes or not).
        assert args.num_layers_per_virtual_pipeline_stage is None, (
            '--num-layers-per-virtual-pipeline-stage should not be used with '
            '--hybrid-layer-pattern. To specify virtual pipelining, describe a number of '
            'pipeline segments in --hybrid-layer-pattern that is a multiple of '
            '--pipeline-model-parallel-size greater than 1'
        )
        assert args.num_virtual_stages_per_pipeline_rank is None, (
            '--num-virtual-stages-per-pipeline-rank should not be used with '
            '--hybrid-layer-pattern. Virtual pipeline stages are derived from the '
            'number of | segments in the pattern.'
        )
        assert args.pipeline_model_parallel_layout is None, (
            '--pipeline-model-parallel-layout should not be used with --hybrid-layer-pattern. '
            'Pipeline stage layout is defined by | separators in the pattern.'
        )
        assert not args.account_for_embedding_in_pipeline_split, (
            '--account-for-embedding-in-pipeline-split should not be used with '
            '--hybrid-layer-pattern. Pipeline stage layout is defined by | separators '
            'in the pattern.'
        )
        assert not args.account_for_loss_in_pipeline_split, (
            '--account-for-loss-in-pipeline-split should not be used with '
            '--hybrid-layer-pattern. Pipeline stage layout is defined by | separators '
            'in the pattern.'
        )
        # Derive VPP from pipe segments in the pattern
        hybrid_pipeline_segments = get_hybrid_total_pipeline_segment_count(
            args.hybrid_layer_pattern
        )
        if hybrid_pipeline_segments == 1 and args.transformer_pipeline_model_parallel_size > 1:
            # No pipes in pattern -- PP will be handled by select_pipeline_segment
            # at model init time (for backwards compatibility).
            args.virtual_pipeline_model_parallel_size = None
        else:
            assert hybrid_pipeline_segments % args.transformer_pipeline_model_parallel_size == 0, (
                'The number of hybrid pipeline segments described by --hybrid-layer-pattern must '
                'be evenly divisible by --pipeline-model-parallel-size. '
                f'Got {hybrid_pipeline_segments} segments and '
                f'{args.transformer_pipeline_model_parallel_size} pipeline parallel size.'
            )
            if hybrid_pipeline_segments > args.transformer_pipeline_model_parallel_size:
                # Must be set here in order to assign virtual parallel ranks in
                # training.py/get_model
                args.virtual_pipeline_model_parallel_size = (
                    hybrid_pipeline_segments // args.transformer_pipeline_model_parallel_size
                )
            else:
                # Exactly one segment per pipeline rank: no virtual pipelining.
                args.virtual_pipeline_model_parallel_size = None
        # Infer mtp_num_layers from unified pattern
        if args.hybrid_layer_pattern and sep in args.hybrid_layer_pattern:
            parsed = parse_hybrid_pattern(args.hybrid_layer_pattern)
            if parsed.mtp_pattern and parsed.mtp_num_depths > 0:
                inferred_mtp_num_layers = parsed.mtp_num_depths
                if args.mtp_num_layers is None:
                    args.mtp_num_layers = inferred_mtp_num_layers
                elif args.mtp_num_layers != inferred_mtp_num_layers:
                    # The pattern is authoritative: warn about the conflict,
                    # then override the explicitly passed flag.
                    warn_rank_0(
                        f"--mtp-num-layers ({args.mtp_num_layers}) conflicts with "
                        f"MTP depth count ({inferred_mtp_num_layers}) in pattern "
                        f"'{args.hybrid_layer_pattern}'. "
                        f"Using the inferred value ({inferred_mtp_num_layers}).",
                        args.rank
                    )
                    args.mtp_num_layers = inferred_mtp_num_layers
# MTP validation
if args.mtp_num_layers:
assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)."
assert args.position_embedding_type == "rope" or args.position_embedding_type == "none", (
f"Multi-Token Prediction (MTP) is not supported with {args.position_embedding_type} position embedding type."
+ f"The supported position embedding types are rope and none."
)
    # Validate MTP args for hybrid vs non-hybrid models
    if args.hybrid_layer_pattern is not None:
        # Mamba/hybrid model MTP validation
        if args.mtp_num_layers and not (args.hybrid_layer_pattern and sep in args.hybrid_layer_pattern):
            # Hybrid model wants MTP but no unified pattern - check for legacy args
            if args.mtp_hybrid_override_pattern is None:
                # NOTE(review): both misconfiguration paths below only warn and
                # never fail — presumably to keep old checkpoints loadable; confirm.
                warn_rank_0(
                    "Hybrid model with --mtp-num-layers but no MTP pattern. "
                    "Use unified --hybrid-layer-pattern with '/' separator (e.g., 'M*M*/MM/MM') "
                    "or legacy --mtp-hybrid-override-pattern for old checkpoints.",
                    args.rank
                )
    else:
        # Non-hybrid (GPT) model MTP validation
        if args.mtp_hybrid_override_pattern is not None:
            warn_rank_0(
                "--mtp-hybrid-override-pattern is for Mamba/hybrid models only. "
                "For GPT models, MTP replicates the main transformer layer structure. "
                "This argument will be ignored.",
                args.rank
            )
    # === End of hybrid layer pattern: deprecation handling and validation ===
    # Uneven virtual pipeline parallelism
    # At most one of the three VPP-specifying flags may be supplied.
    assert (
        int(args.num_layers_per_virtual_pipeline_stage is not None)
        + int(args.num_virtual_stages_per_pipeline_rank is not None)
        + int(args.pipeline_model_parallel_layout is not None)
    ) <= 1, (
        'No more than one of the following arguments can be set at the same time: '
        '--num-layers-per-virtual-pipeline-stage, --num-virtual-stages-per-pipeline-rank,'
        '--pipeline-model-parallel-layout. '
        f'{args.num_layers_per_virtual_pipeline_stage=}, '
        f'{args.num_virtual_stages_per_pipeline_rank=}, '
        f'{args.pipeline_model_parallel_layout=}.'
    )
    if args.pipeline_model_parallel_layout is not None:
        # Parse the input flattened layout to a list and get the vpp size.
        # We will validate the layout more carefully in the TransformerConfig constructor.
        num_stages = PipelineParallelLayerLayout.get_num_stages_from_str(args.pipeline_model_parallel_layout)
        assert num_stages % args.pipeline_model_parallel_size == 0, (
            f"The length of pipeline_model_parallel_layout must be divisible"
            f" by pipeline_model_parallel_size ({num_stages=},"
            f" {args.pipeline_model_parallel_size=})"
        )
        args.virtual_pipeline_model_parallel_size = num_stages // args.pipeline_model_parallel_size
        # A VPP size of 1 means no virtual pipelining; normalize to None.
        if args.virtual_pipeline_model_parallel_size == 1:
            args.virtual_pipeline_model_parallel_size = None
elif args.num_layers_per_virtual_pipeline_stage is not None or args.num_virtual_stages_per_pipeline_rank is not None:
if args.num_virtual_stages_per_pipeline_rank is None:
assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, \
'please use --num-virtual-stages-per-pipeline-rank to specify virtual pipeline parallel degree when enable uneven pipeline parallelism'
if args.num_layers is not None:
num_layers = args.num_layers
else:
num_layers = args.decoder_num_layers
if args.account_for_embedding_in_pipeline_split:
num_layers += 1
if args.account_for_loss_in_pipeline_split:
num_layers += 1
assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'number of layers of the model must be divisible pipeline model parallel size'
num_layers_per_pipeline_stage = num_layers // args.transformer_pipeline_model_parallel_size
assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage'
args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = args.num_virtual_stages_per_pipeline_rank
if args.virtual_pipeline_model_parallel_size == 1:
args.virtual_pipeline_model_parallel_size = None
else:
# Only set VPP to None if it wasn't already derived from --hybrid-layer-pattern
if args.hybrid_layer_pattern is None:
args.virtual_pipeline_model_parallel_size = None
    if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:
        # Divisibility check not applicable for T5 models which specify encoder_num_layers
        # and decoder_num_layers, or for hybrid models using --hybrid-layer-pattern.
        if args.num_layers is not None and args.hybrid_layer_pattern is None:
            num_layers = args.num_layers
            # Embedding / loss may each be counted as one extra layer in the split.
            if args.account_for_embedding_in_pipeline_split:
                num_layers += 1
            if args.account_for_loss_in_pipeline_split:
                num_layers += 1
            assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'Number of layers should be divisible by the pipeline-model-parallel size'
    if args.virtual_pipeline_model_parallel_size is not None:
        # Interleaved (virtual pipeline) schedule sanity checks.
        if args.overlap_p2p_comm:
            assert args.pipeline_model_parallel_size > 1, \
                'When interleaved schedule is used, pipeline-model-parallel size '\
                'should be greater than 1'
        else:
            assert args.pipeline_model_parallel_size > 2, \
                'When interleaved schedule is used and p2p communication overlap is disabled, '\
                'pipeline-model-parallel size should be greater than 2 to avoid having multiple '\
                'p2p sends and recvs between same 2 ranks per communication batch'
    else:
        # Overlap P2P communication is disabled if not using the interleaved schedule.
        args.overlap_p2p_comm = False
        args.align_param_gather = False
        # Only print warning if PP size > 1.
        if args.rank == 0 and args.pipeline_model_parallel_size > 1:
            print('WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False '
                  'since non-interleaved schedule does not support overlapping p2p communication '
                  'and aligned param AG')
    print_rank_0(
        f"Number of virtual stages per pipeline stage: {args.virtual_pipeline_model_parallel_size}"
    )
    if args.overlap_param_gather:
        # Overlapped parameter all-gather is only implemented for these
        # optimizer/sharding configurations, and requires overlapped grad reduce.
        assert args.use_distributed_optimizer or args.use_megatron_fsdp \
            or args.optimizer == 'dist_muon', \
            '--overlap-param-gather only supported with distributed optimizer, megatron fsdp, or dist_muon'
        assert args.overlap_grad_reduce, \
            'Must use --overlap-param-gather with --overlap-grad-reduce'
        assert not args.use_legacy_models, \
            '--overlap-param-gather only supported with MCore models'
if args.use_torch_fsdp2:
assert is_torch_min_version("2.4.0"), \
'FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.'
assert args.pipeline_model_parallel_size == 1, \
'--use-torch-fsdp2 is not supported with pipeline parallelism'
assert args.expert_model_parallel_size == 1, \
'--use-torch-fsdp2 is not supported with expert parallelism'
assert not args.use_distributed_optimizer, \
"--use-torch-fsdp2 is not supported with MCore's distributed optimizer"
assert not args.gradient_accumulation_fusion, \
'--use-torch-fsdp2 is not supported with gradient accumulation fusion'
assert args.ckpt_format in ('torch_dist', 'torch_dcp'), \
'--use-torch-fsdp2 requires --ckpt-format torch_dist or torch_dcp'
assert args.untie_embeddings_and_output_weights, \
'--use-torch-fsdp2 requires --untie-embeddings-and-output-weights'
assert not args.fp16, \
'--use-torch-fsdp2 not supported with fp16 yet'
assert os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \
'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one'
if args.fp8_param_gather and is_te_min_version("2.0.0"):
args.fp8_param_gather = False
warn_rank_0(
'FSDP2 FP8 param gather is not supported yet in TE 2.0, will fallback to bf16'
'all_gather instead, turning off fp8_param_gather',
args.rank,
)
if args.fp4_param and not is_te_min_version("2.7.0.dev0"):
raise ValueError("--fp4-param requires Transformer Engine >= 2.7.0.dev0.")
    if args.overlap_param_gather_with_optimizer_step:
        # Requires the distributed optimizer, overlapped param gather, and
        # interleaved PP; incompatible with distributed checkpointing for now.
        assert args.use_distributed_optimizer, \
            '--overlap-param-gather-with-optimizer-step only supported with distributed optimizer'
        assert args.overlap_param_gather, \
            'Must use --overlap-param-gather-with-optimizer-step with --overlap-param-gather'
        assert args.virtual_pipeline_model_parallel_size is not None, \
            '--overlap-param-gather-with-optimizer-step only supported with interleaved pipeline parallelism'
        assert not args.use_dist_ckpt, \
            '--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet'
# Map string data-type to torch.dtype.
dtype_map = {
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
}
map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]
args.main_grads_dtype = map_dtype(args.main_grads_dtype)
args.main_params_dtype = map_dtype(args.main_params_dtype)
args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
args.mamba_inference_conv_states_dtype = map_dtype(args.mamba_inference_conv_states_dtype)
args.mamba_inference_ssm_states_dtype = map_dtype(args.mamba_inference_ssm_states_dtype)
args.megatron_fsdp_main_params_dtype = map_dtype(args.megatron_fsdp_main_params_dtype)
args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
if args.grad_reduce_in_bf16:
assert args.megatron_fsdp_grad_comm_dtype == torch.bfloat16, \
"When --grad-reduce-in-bf16 is set, --megatron-fsdp-grad-comm-dtype must be bfloat16"
    if args.fp8_param_gather:
        # `not torch.is_grad_enabled()` serves as the "inference mode" escape hatch
        # mentioned in the message below.
        assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \
            '--fp8-param-gather only supported with distributed optimizer, torch fsdp2, megatron fsdp, or inference mode'
    # FP4 and FP8 are mutually exclusive
    if args.fp4 and args.fp8:
        raise ValueError("--fp4-format and --fp8-format cannot be used simultaneously. Please choose one.")
    # FP4 param requires FP4 mode
    if args.fp4_param and not args.fp4:
        raise ValueError("--fp4-param-gather must be used together with --fp4-format.")
    # FP4 requires TE >= 2.7.0.dev0
    if args.fp4 and not is_te_min_version("2.7.0.dev0"):
        raise ValueError("--fp4-format requires Transformer Engine >= 2.7.0.dev0 for NVFP4BlockScaling support.")
    if (
        args.fp8_recipe == 'mxfp8'
        and args.transformer_impl == 'inference_optimized'
        and not is_flashinfer_min_version("0.6.4")
    ):
        raise ValueError("MXFP8 with inference optimized layers requires FlashInfer >= 0.6.4")
if args.use_megatron_fsdp:
# NOTE: The flag `use_custom_fsdp` is deprecated and will be removed in future versions.
# Please use `use_megatron_fsdp` instead, as all functionality will be migrated there.
# Future updates will drop support for `use_custom_fsdp` to avoid confusion.
args.use_custom_fsdp = True
if args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]:
warn_rank_0(
'Please make sure your TransformerEngine support FSDP + gradient accumulation fusion',
args.rank,
)
if args.data_parallel_sharding_strategy == "optim_grads_params":
assert args.check_weight_hash_across_dp_replicas_interval is None, \
'check_weight_hash_across_dp_replicas_interval is not supported with optim_grads_params'
assert os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \
'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one'
assert args.ckpt_format == "fsdp_dtensor", \
"Megatron FSDP only supports fsdp_dtensor checkpoint format"
if args.fsdp_manual_registration:
assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP"
assert args.nccl_ub, "FSDP manual registration is only supported with nccl-ub option"
if args.use_megatron_fsdp:
args.reuse_grad_buf_for_mxfp8_param_ag = False
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
# Turn off checking for NaNs in loss and grads if using dynamic loss scaling,
# where NaNs in grads / loss are signal to the loss scaler.
if not args.loss_scale:
args.check_for_nan_in_loss_and_grad = False
warn_rank_0('Setting args.check_for_nan_in_loss_and_grad to False since '
'dynamic loss scaling is being used')
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if args.accumulate_allreduce_grads_in_fp32:
assert args.main_grads_dtype == torch.float32, \
"--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set"
if args.grad_reduce_in_bf16:
args.accumulate_allreduce_grads_in_fp32 = False
elif not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32:
args.accumulate_allreduce_grads_in_fp32 = True
print_rank_0('accumulate and all-reduce gradients in fp32 for bfloat16 data type.')
if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
assert not args.check_for_nan_in_loss_and_grad, \
"--no-check-for-nan-in-loss-and-grad should be set with --cuda-graph-scope=full_iteration for training. Note: If you are trying to use full_iteration CUDA graphs for inference, please use --cuda-graph-scope full_iteration_inference instead"
if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration_inference in args.cuda_graph_scope:
assert args.fp8 is None, \
"fp8 is not supported with inference dynamic batching and full_iteration_inference CUDA graph"
if args.cuda_graph_impl == 'local':
assert args.inference_dynamic_batching_num_cuda_graphs > 0 or args.inference_dynamic_batching_num_cuda_graphs == -1, \
'inference_dynamic_batching_num_cuda_graphs should be a positive integer or -1' \
'-1 means that we will automatically determine the number of CUDA graphs to capture based on the `max_requests` value.'
    print_rank_0('using {} for parameters ...'.format(args.params_dtype))
    if args.dataloader_type is None:
        args.dataloader_type = 'single'
    # data
    assert args.num_dataset_builder_threads > 0
    # Consumed tokens.
    args.consumed_train_samples = 0
    args.skipped_train_samples = 0
    args.consumed_valid_samples = 0
    if args.rl_use_sequence_packing:
        # Track consumed sequence-packing bins separately when RL sequence packing is enabled.
        args.consumed_train_bins = 0
    # Support for variable sequence lengths across batches/microbatches.
    # set it if the dataloader supports generation of variable sequence lengths
    # across batches/microbatches. Due to additional communication overhead
    # during pipeline parallelism, it should not be set if sequence length
    # is constant during training.
    args.variable_seq_lengths = False
    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        # NOTE(review): lr_warmup_samples is compared against 0 (not None) —
        # presumably 0 is its argparse default; confirm against the parser.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'