# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import time
import warnings
from typing import Any, Callable, Optional, TypeVar
import torch
from megatron.bridge import AutoBridge
from megatron.bridge.models.model_provider import get_model
from megatron.bridge.peft.lora import LoRA
from megatron.bridge.training import fault_tolerance
from megatron.bridge.training.checkpointing import (
_load_checkpoint_from_path,
checkpoint_exists,
init_checkpointing_context,
load_checkpoint,
)
from megatron.bridge.training.config import (
CheckpointConfig,
ConfigContainer,
DistributedDataParallelConfig,
LoggerConfig,
OptimizerConfig,
SchedulerConfig,
TokenizerConfig,
TrainingConfig,
)
from megatron.bridge.training.initialize import (
initialize_megatron,
set_jit_fusion_options,
)
from megatron.bridge.training.optim import setup_optimizer
from megatron.bridge.training.setup import (
_create_peft_pre_wrap_hook,
_update_model_config_funcs,
)
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer
from megatron.bridge.training.utils.pg_utils import get_pg_collection
from megatron.bridge.utils.instantiate_utils import InstantiationMode
from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
from megatron.core import parallel_state
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.transformer_config import TransformerConfig
from transformers import PreTrainedTokenizerBase
from nemo_rl.distributed.model_utils import patch_gpt_model_forward_for_linear_ce_fusion
try:
from megatron.core.distributed import (
TorchFullyShardedDataParallel as torch_FSDP, # noqa: F401 unused-import
)
HAVE_FSDP2 = True
except ImportError:
HAVE_FSDP2 = False
from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams
from nemo_rl.distributed.named_sharding import NamedSharding
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig
from nemo_rl.models.megatron.draft.utils import (
build_draft_model,
find_draft_owner_chunk,
get_attached_draft_model,
)
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy.utils import (
configure_dynamo_cache,
get_megatron_checkpoint_dir,
)
TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
def destroy_parallel_state():
"""Safely destroy parallel state and reset async call tracking.
This function is called during initialization to clean up temporary distributed
state from model import operations. Resetting async call tracking ensures that
when the main Megatron distributed context is created, all ranks start with
consistent call_idx values for async checkpointing.
"""
if torch.distributed.is_initialized():
try:
torch.distributed.barrier()
torch.distributed.destroy_process_group()
        except Exception:
            pass  # Ignore errors if already destroyed
if hasattr(parallel_state, "destroy_model_parallel"):
try:
parallel_state.destroy_model_parallel()
        except Exception:
            pass  # Ignore errors if already destroyed
# Also reset the Megatron async calls queue if it exists
try:
import megatron.training.async_utils as megatron_async_utils
from megatron.core.dist_checkpointing.strategies.async_utils import (
AsyncCallsQueue,
)
# Clean up any existing async callers first
old_call_idx = getattr(
megatron_async_utils._async_calls_queue, "call_idx", None
)
if megatron_async_utils._async_calls_queue is not None:
num_unfinalized = (
megatron_async_utils._async_calls_queue.get_num_unfinalized_calls()
)
if num_unfinalized > 0:
print(
f"[WARNING] Resetting Megatron async calls queue with {num_unfinalized} unfinalized calls"
)
try:
megatron_async_utils._async_calls_queue.close()
            except Exception:
                pass  # Ignore errors during cleanup
# Reset the Megatron global async calls queue as well
megatron_async_utils._async_calls_queue = AsyncCallsQueue()
print(
f"[DEBUG] Reset Megatron async calls queue (old call_idx: {old_call_idx})"
)
except ImportError:
pass
def setup_distributed() -> None:
"""Handle NCCL settings, dtype mapping, and basic config setup."""
# Disable dynamo autotune_local_cache to avoid crash when there's already a cache
# with different order of node_bundles
configure_dynamo_cache()
# Ensure clean slate before import
destroy_parallel_state()
# Need to initialize the process group before calling into Megatron-Bridge, otherwise Megatron-Bridge will try to set an incorrect device
torch.distributed.init_process_group("nccl")
def validate_and_set_config(
config,
rank,
hf_model_name,
pretrained_path,
weights_path,
optimizer_path,
):
# Handle generation configuration
is_generation_colocated = None
sampling_params = None
if "generation" in config and config["generation"] is not None:
generation_cfg = config["generation"]
# set generation colocated
is_generation_colocated = generation_cfg["colocated"]["enabled"]
# set sampling params
sampling_params = TrainingSamplingParams(
top_k=generation_cfg["top_k"],
top_p=generation_cfg["top_p"],
temperature=generation_cfg["temperature"],
)
# Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator.
# See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details.
if not is_generation_colocated:
os.environ["NCCL_CUMEM_ENABLE"] = "1"
# Setup data types
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
dtype = dtype_map[config["precision"]]
# Optimizer configuration
optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"]
offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"]
# Reward models are not yet supported with Megatron.
if "reward_model_cfg" in config and config["reward_model_cfg"]["enabled"]:
raise NotImplementedError(
"Reward models are not yet supported with the Megatron backend, this issue is "
"tracked in https://github.com/NVIDIA-NeMo/RL/issues/720"
)
# Validate yarn rope_scaling fields are fully specified
rope_scaling = (config.get("hf_config_overrides") or {}).get("rope_scaling") or {}
if rope_scaling.get("rope_type") == "yarn":
_YARN_REQUIRED_FIELDS = (
"factor",
"rope_theta",
"original_max_position_embeddings",
"truncate",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
missing = [f for f in _YARN_REQUIRED_FIELDS if f not in rope_scaling]
assert not missing, (
f"rope_scaling.rope_type is 'yarn' but the following required fields are not set: "
f"{missing}. Please specify all of {list(_YARN_REQUIRED_FIELDS)} in "
f"policy.hf_config_overrides.rope_scaling."
)
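    # Illustrative fully specified yarn override (hypothetical values, not tuned
    # recommendations) under policy.hf_config_overrides.rope_scaling:
    #   {"rope_type": "yarn", "factor": 4.0, "rope_theta": 10000.0,
    #    "original_max_position_embeddings": 32768, "truncate": False,
    #    "beta_fast": 32, "beta_slow": 1, "mscale": 1.0, "mscale_all_dim": 0.0}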
megatron_cfg, model_cfg = setup_model_config(
config,
rank,
dtype,
hf_model_name,
pretrained_path,
weights_path,
optimizer_path,
)
final_padded_vocab_size = calculate_padded_vocab_size(
megatron_cfg.model.vocab_size,
megatron_cfg.model.make_vocab_size_divisible_by,
config["megatron_cfg"]["tensor_model_parallel_size"],
)
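    # Worked example, assuming the usual Megatron padding rule of rounding the vocab
    # size up to a multiple of make_vocab_size_divisible_by * tensor_model_parallel_size
    # (hypothetical numbers): vocab_size=32000, make_vocab_size_divisible_by=128, TP=4
    # pads to 32256, the next multiple of 512.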
return RuntimeConfig(
megatron_cfg,
model_cfg,
dtype,
optimizer_cpu_offload,
offload_optimizer_for_logprob,
is_generation_colocated,
sampling_params,
final_padded_vocab_size,
)
def _canonicalize_hf_config_overrides(overrides: dict[str, Any]) -> str:
"""Return a stable JSON string for hf_config_overrides."""
return json.dumps(
overrides, sort_keys=True, separators=(",", ":"), ensure_ascii=True
)
def _get_hf_config_overrides_hash(overrides: dict[str, Any]) -> str:
"""Return a short stable hash for hf_config_overrides."""
canonical = _canonicalize_hf_config_overrides(overrides)
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]
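# Note (illustrative, hypothetical values): because _canonicalize_hf_config_overrides
# serializes with sort_keys=True and fixed separators, key order does not affect the
# result, so _get_hf_config_overrides_hash({"a": 1, "b": 2}) equals
# _get_hf_config_overrides_hash({"b": 2, "a": 1}); the hash is the first 12 hex
# characters of the SHA-256 digest of that canonical JSON string.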
def validate_model_paths(config: PolicyConfig) -> tuple[str, str, bool]:
"""Validate and setup model paths."""
# cfg["model_name"] is allowed to be either an HF model name or a path to an HF checkpoint
hf_model_name = config["model_name"]
hf_config_overrides = config.get("hf_config_overrides", {}) or {}
# Check if the checkpoint already exists
hf_model_subdir = hf_model_name
if os.path.exists(hf_model_name):
hf_model_subdir = f"model_{hf_model_subdir.replace('/', '_')}"
if hf_config_overrides:
overrides_hash = _get_hf_config_overrides_hash(hf_config_overrides)
hf_model_subdir = f"{hf_model_subdir}__hfovr_{overrides_hash}"
pretrained_path = os.path.join(get_megatron_checkpoint_dir(), hf_model_subdir)
pt_checkpoint_exists = os.path.exists(pretrained_path) and os.path.exists(
os.path.join(pretrained_path, "iter_0000000")
)
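    # Path layout sketch (hypothetical model name): for model_name="org/model" with
    # hf_config_overrides set, pretrained_path resolves to
    #   <get_megatron_checkpoint_dir()>/org/model__hfovr_<12-char hash>
    # and the converted checkpoint counts as present once
    # <pretrained_path>/iter_0000000 exists. The "model_" prefix (with '/' replaced
    # by '_') is only applied when model_name is a local path that exists on disk.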
return hf_model_name, pretrained_path, pt_checkpoint_exists
def setup_model_config(
config: PolicyConfig,
rank,
dtype,
hf_model_name: str,
pretrained_path: str,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
) -> tuple[ConfigContainer, Any]:
"""Handle all the model configuration logic."""
# Load pretrained run config
pretrained_run_config = os.path.join(
pretrained_path, "iter_0000000/run_config.yaml"
)
if not os.path.exists(pretrained_run_config):
raise FileNotFoundError(
f"Pretrained run config not found at {pretrained_run_config} on rank={rank}. "
"This usually means that the one-time HF->mcore conversion on rank=0 saved to a directory "
"not being mounted on this node. Please check"
)
try:
cfg_from_pretrained = ConfigContainer.from_yaml(
pretrained_run_config, mode=InstantiationMode.STRICT
)
except Exception as e:
# Add helpful context as a note to the exception
e.add_note(
f"\n{'=' * 80}\n"
f"NOTE: A common cause of this error is when the HF->mcore converted checkpoint is\n"
f"created with an older version of megatron-bridge.\n"
f"If this checkpoint is old or was generated by a different code version,\n"
f"try deleting it and rerunning the code.\n"
f"The checkpoint will be automatically regenerated with the current version.\n\n"
f"Checkpoint location: {pretrained_path}\n"
f"{'=' * 80}"
)
raise
model_cfg = cfg_from_pretrained.model
cfg_from_pretrained.logger = LoggerConfig()
# Apply parallelism settings
_apply_parallelism_config(model_cfg, config)
# Apply MoE settings
_apply_moe_config(model_cfg, config)
# Apply MTP settings
_apply_mtp_config(model_cfg, config)
# Apply precision settings
_apply_precision_config(model_cfg, config, dtype)
# Apply performance settings
_apply_performance_config(model_cfg, config)
# Validate optimizer configuration
_validate_optimizer_config(config)
# Optional layernorm epsilon
if "layernorm_epsilon" in config["megatron_cfg"]:
model_cfg.layernorm_epsilon = config["megatron_cfg"]["layernorm_epsilon"]
# Validate chunking configuration
_validate_chunking_config(config)
# Create checkpoint configs
checkpoint_config = _create_checkpoint_config(
pretrained_path, weights_path, optimizer_path
)
# Validate training configuration
_validate_training_config(config, model_cfg)
# Create final megatron config
megatron_cfg = _create_megatron_config(
model_cfg, checkpoint_config, config, hf_model_name, dtype
)
_validate_dtype_config(dtype, megatron_cfg.model, megatron_cfg.optimizer)
return megatron_cfg, model_cfg
def _apply_parallelism_config(model_cfg: Any, config: PolicyConfig) -> None:
"""Apply tensor/pipeline/context parallelism configuration."""
model_cfg.tensor_model_parallel_size = config["megatron_cfg"][
"tensor_model_parallel_size"
]
model_cfg.pipeline_model_parallel_size = config["megatron_cfg"][
"pipeline_model_parallel_size"
]
model_cfg.num_layers_in_first_pipeline_stage = config["megatron_cfg"][
"num_layers_in_first_pipeline_stage"
]
model_cfg.num_layers_in_last_pipeline_stage = config["megatron_cfg"][
"num_layers_in_last_pipeline_stage"
]
model_cfg.sequence_parallel = config["megatron_cfg"]["sequence_parallel"]
model_cfg.context_parallel_size = config["megatron_cfg"]["context_parallel_size"]
if model_cfg.context_parallel_size > 1:
assert config["sequence_packing"]["enabled"], (
"Sequence Packing must be enabled to use Context Parallelism with MCore"
)
assert not config["megatron_cfg"].get("use_linear_ce_fusion_loss", False), (
"Context Parallelism is not supported with linear CE fusion loss, please set use_linear_ce_fusion_loss to false"
)
def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None:
"""Apply Mixture of Experts configuration."""
model_cfg.expert_tensor_parallel_size = config["megatron_cfg"][
"expert_tensor_parallel_size"
]
model_cfg.expert_model_parallel_size = config["megatron_cfg"][
"expert_model_parallel_size"
]
# MoE stability settings
# Setting moe_router_dtype to higher precision (e.g. fp64) can improve numerical stability,
# especially when using many experts.
model_cfg.moe_router_dtype = config["megatron_cfg"]["moe_router_dtype"]
# The below two configs (and "freeze_moe_router") are used to stabilize moe training
# by preventing updates to the moe router. We found that this is helpful in reducing
# logprob error during training.
# Set this to "none" to disable load balancing loss.
model_cfg.moe_router_load_balancing_type = config["megatron_cfg"][
"moe_router_load_balancing_type"
]
# Set this to 0.0 to disable updates to the moe router expert bias
model_cfg.moe_router_bias_update_rate = config["megatron_cfg"][
"moe_router_bias_update_rate"
]
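    # Illustrative stability-oriented router settings (hypothetical YAML under
    # policy.megatron_cfg; tune per model):
    #   moe_router_dtype: fp64                  # higher-precision router math
    #   moe_router_load_balancing_type: none    # disables the load balancing loss
    #   moe_router_bias_update_rate: 0.0        # freezes expert-bias updates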
model_cfg.moe_enable_deepep = config["megatron_cfg"]["moe_enable_deepep"]
model_cfg.moe_token_dispatcher_type = config["megatron_cfg"][
"moe_token_dispatcher_type"
]
model_cfg.moe_shared_expert_overlap = config["megatron_cfg"][
"moe_shared_expert_overlap"
]
model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"]
if "moe_grouped_gemm" in config["megatron_cfg"]:
model_cfg.moe_grouped_gemm = config["megatron_cfg"]["moe_grouped_gemm"]
def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None:
if "mtp_num_layers" in config["megatron_cfg"]:
model_cfg.mtp_num_layers = config["megatron_cfg"]["mtp_num_layers"]
def _apply_precision_config(
model_cfg: Any, config: PolicyConfig, dtype: torch.dtype
) -> None:
"""Apply precision and dtype configuration."""
model_cfg.bf16 = dtype == torch.bfloat16
model_cfg.fp16 = dtype == torch.float16
if model_cfg.fp16:
assert not model_cfg.bf16, "fp16 and bf16 cannot be used together"
model_cfg.params_dtype = torch.float16
elif model_cfg.bf16:
assert not model_cfg.fp16, "fp16 and bf16 cannot be used together"
model_cfg.params_dtype = torch.bfloat16
else:
model_cfg.params_dtype = torch.float32
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
model_cfg.pipeline_dtype = dtype_map[config["megatron_cfg"]["pipeline_dtype"]]
def _apply_performance_config(model_cfg: Any, config: PolicyConfig) -> None:
"""Apply performance optimization configuration."""
model_cfg.parallel_output = True
# Activation checkpointing
if config["megatron_cfg"]["activation_checkpointing"]:
model_cfg.recompute_granularity = "full"
model_cfg.recompute_method = "uniform"
model_cfg.recompute_num_layers = 1
# Activation function validation
if not model_cfg.gated_linear_unit:
assert model_cfg.activation_func is not None, (
"activation_func must be set if not using gated_linear_unit. This likely "
"indicates an issue in configuration conversion (e.g. activation func was "
"a lambda and couldn't be serialized). This is based on this check "
"https://github.com/NVIDIA/Megatron-LM/blob/1ab876ddc4c1893c76f26d775226a8d1dcdfb3d2/megatron/core/transformer/mlp.py#L174."
)
# Fusion settings
model_cfg.apply_rope_fusion = config["megatron_cfg"]["apply_rope_fusion"]
model_cfg.bias_activation_fusion = config["megatron_cfg"]["bias_activation_fusion"]
# Optional explicit attention backend override for environments where
# TE auto backend probing is unstable.
attention_backend = config["megatron_cfg"].get("attention_backend")
if attention_backend is not None:
for _nvte_var in ("NVTE_FUSED_ATTN", "NVTE_FLASH_ATTN", "NVTE_UNFUSED_ATTN"):
os.environ.pop(_nvte_var, None)
try:
model_cfg.attention_backend = AttnBackend[attention_backend]
except KeyError:
raise ValueError(
f"Invalid attention backend: {attention_backend}. "
f"Available backends are: {list(AttnBackend.__members__.keys())}"
)
# FP8 configuration
fp8_cfg = config["megatron_cfg"].get("fp8_cfg", None)
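    # Illustrative fp8_cfg shape (hypothetical values; valid fp8/fp8_recipe choices
    # depend on the installed Transformer Engine / Megatron version):
    #   fp8_cfg = {"enabled": True, "fp8": "hybrid", "fp8_recipe": "delayed", "fp8_param": False}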
if fp8_cfg is not None and fp8_cfg.get("enabled", False):
try:
model_cfg.fp8 = fp8_cfg["fp8"]
model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
model_cfg.fp8_param = fp8_cfg["fp8_param"]
except KeyError as e:
raise KeyError(f"Missing key in fp8_cfg: {e}")
if model_cfg.fp8_param:
warnings.warn(
"Setting fp8_param=True sometimes causes NaN token_mult_prob_error, please use with caution. "
"Refer to https://github.com/NVIDIA-NeMo/RL/issues/1164 for latest updates with this issue."
)
def _validate_optimizer_config(config: PolicyConfig) -> None:
"""Validate optimizer configuration."""
optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"]
optimizer_offload_fraction = config["megatron_cfg"]["optimizer"][
"optimizer_offload_fraction"
]
if optimizer_cpu_offload:
        # Currently, a hybrid optimizer (partly on GPU and partly on CPU) is not supported because it conflicts
        # with the way NeMo RL handles optimizer offload/onload between generation and training. So when CPU
        # offloading is enabled, the offload fraction must be 1.0.
assert optimizer_offload_fraction == 1.0, (
"Currently for optimizer offloading, only optimizer_offload_fraction=1.0 is supported"
)
def _validate_chunking_config(config: PolicyConfig) -> None:
"""Validate chunking configuration."""
if (
"logprob_chunk_size" in config
and config["logprob_chunk_size"] is not None
and config["logprob_chunk_size"] > 0
):
assert config["megatron_cfg"]["defer_fp32_logits"], (
"defer_fp32_logits must be True if logprob_chunk_size is set"
)
def _create_checkpoint_config(
pretrained_path: str, weights_path: Optional[str], optimizer_path: Optional[str]
) -> CheckpointConfig:
"""Create checkpoint configurations."""
return CheckpointConfig(
save_interval=100,
save=weights_path,
load=weights_path,
load_optim=optimizer_path is not None,
pretrained_checkpoint=pretrained_path,
async_save=False,
fully_parallel_save=True,
fully_parallel_load=True,
load_rng=False,
)
def _validate_training_config(config: PolicyConfig, model_cfg: Any) -> None:
"""Validate training configuration."""
assert "train_iters" in config["megatron_cfg"], (
"train_iters must be set in megatron_cfg. For an example, see "
"https://github.com/NVIDIA-NeMo/RL/blob/bccbc377705a81a1f4b3c31ad9767bcc15f735a8/nemo_rl/algorithms/sft.py#L175-L179."
)
## These settings are required for correct gradient computations in mcore
## when calculate_per_token_loss is True, there is no scaling of the gradient in mcore,
## so we handle the scaling in nemo-rl.
## perform_initialization = True is a workaround to ensure the correct tensor parallel attributes are set
## on the TP-sharded parameters.
model_cfg.calculate_per_token_loss = True
model_cfg.perform_initialization = True
# MoE aux loss validation
assert (
"aux_loss" not in model_cfg.moe_router_load_balancing_type
or model_cfg.moe_aux_loss_coeff == 0
), (
"MoE aux loss is currently not supported due to a known bug in Megatron-LM. "
"See https://github.com/NVIDIA/Megatron-LM/issues/1984 for more details."
)
def _validate_dtype_config(
dtype: torch.dtype, model_cfg: Any, optimizer_cfg: Any
) -> None:
# TODO: this validation should happen inside mbridge: https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1665
if dtype == torch.bfloat16:
assert model_cfg.bf16 == True, (
"policy.megatron_cfg.model.bf16=True must be set if policy.precision=bfloat16. This is handled by nemo-rl so this indicates something is misconfigured."
)
assert (
optimizer_cfg.use_precision_aware_optimizer == False
or optimizer_cfg.bf16 == True
), (
"policy.megatron_cfg.optimizer.bf16=True must be set if policy.precision=bfloat16 when using use_precision_aware_optimizer=True"
)
elif dtype == torch.float16:
assert model_cfg.fp16 == True, (
"policy.megatron_cfg.model.fp16=True must be set if policy.precision=float16. This is handled by nemo-rl so this indicates something is misconfigured."
)
assert (
optimizer_cfg.use_precision_aware_optimizer == False
or optimizer_cfg.fp16 == True
), (
"policy.megatron_cfg.optimizer.fp16=True must be set if policy.precision=float16 when using use_precision_aware_optimizer=True"
)
elif dtype == torch.float32:
assert model_cfg.bf16 == False and model_cfg.fp16 == False, (
"policy.megatron_cfg.model.bf16=False and policy.megatron_cfg.model.fp16=False must be set if policy.precision=float32. This is handled by nemo-rl so this indicates something is misconfigured."
)
assert optimizer_cfg.bf16 == False and optimizer_cfg.fp16 == False, (
"policy.megatron_cfg.optimizer.bf16=False and policy.megatron_cfg.optimizer.fp16=False must be set if policy.precision=float32"
)
def _create_megatron_config(
model_cfg: Any,
checkpoint_config: CheckpointConfig,
config: PolicyConfig,
hf_model_name: str,
dtype: torch.dtype,
) -> ConfigContainer:
"""Create the final Megatron configuration container."""
return ConfigContainer(
model=model_cfg,
checkpoint=checkpoint_config,
logger=LoggerConfig(logging_level=0),
train=TrainingConfig(
micro_batch_size=1, # ignored
global_batch_size=config["train_global_batch_size"], # ignored
train_iters=config["megatron_cfg"]["train_iters"],
),
optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]),
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=config["megatron_cfg"][
"distributed_data_parallel_config"
]["grad_reduce_in_fp32"],
overlap_grad_reduce=config["megatron_cfg"][
"distributed_data_parallel_config"
]["overlap_grad_reduce"],
overlap_param_gather=config["megatron_cfg"][
"distributed_data_parallel_config"
]["overlap_param_gather"],
            # We need to set average_in_collective=False when calculate_per_token_loss=True;
            # otherwise, mcore throws an assertion error.
average_in_collective=False, # Required with calculate_per_token_loss=True
use_distributed_optimizer=config["megatron_cfg"]["optimizer"][
"use_distributed_optimizer"
],
data_parallel_sharding_strategy=config["megatron_cfg"][
"distributed_data_parallel_config"
]["data_parallel_sharding_strategy"],
),
scheduler=SchedulerConfig(**config["megatron_cfg"]["scheduler"]),
dataset=None,
tokenizer=TokenizerConfig(
tokenizer_type="HuggingFaceTokenizer",
tokenizer_model=hf_model_name,
),
)
def _create_draft_pre_wrap_hook(
policy_cfg: PolicyConfig,
megatron_cfg: ConfigContainer,
state: GlobalState,
*,
preload_policy_from_pretrained: bool,
) -> Callable[[list[MegatronModule]], list[MegatronModule]]:
"""Create the hook that attaches draft weights before mixed-precision/DDP wrapping."""
draft_cfg = policy_cfg["draft"]
def draft_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]:
"""Optionally preload the base policy, then attach the draft module to the owner chunk."""
if not draft_cfg["enabled"]:
return model
# Base pretrained checkpoints do not contain draft weights, so load the
# policy weights before attaching the nested draft module.
if preload_policy_from_pretrained:
pretrained_checkpoint = megatron_cfg.checkpoint.pretrained_checkpoint
if pretrained_checkpoint is None or not checkpoint_exists(
pretrained_checkpoint
):
raise ValueError(
f"Invalid pretrained checkpoint directory found: {pretrained_checkpoint}"
)
megatron_cfg.checkpoint.finetune = True
_load_checkpoint_from_path(
load_dir=pretrained_checkpoint,
state=state,
model=model,
optimizer=None,
opt_param_scheduler=None,
checkpointing_context={},
skip_load_to_model_and_opt=False,
ignore_ckpt_step=True,
)
draft_owner = find_draft_owner_chunk(model)
if draft_owner is None:
return model
if getattr(draft_owner, "draft_model", None) is not None:
raise RuntimeError(
"Policy model chunk already has an attached `draft_model`."
)
pg_collection = get_pg_collection(model)
draft_model = build_draft_model(
megatron_cfg.model,
draft_config=draft_cfg,
pg_collection=pg_collection,
policy_model_chunk=draft_owner,
)
if draft_model is not None:
setattr(draft_owner, "draft_model", draft_model)
return model
return draft_pre_wrap_hook
def setup_model_and_optimizer(
policy_cfg: PolicyConfig,
megatron_cfg: ConfigContainer,
load_optimizer: bool = True,
get_embedding_ranks=None, # TODO @sahilj: What is this?
get_position_embedding_ranks=None,
):
state = GlobalState()
state.cfg = megatron_cfg
# TODO: Freeze state.cfg
megatron_cfg.dist.external_gpu_device_mapping = True
initialize_megatron(
cfg=megatron_cfg,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
)
if megatron_cfg.ft and megatron_cfg.ft.enable_ft_package:
fault_tolerance.setup(megatron_cfg, state)
fault_tolerance.maybe_setup_simulated_fault(megatron_cfg.ft)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options(megatron_cfg.model, megatron_cfg.train.micro_batch_size)
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
start_time_tensor = torch.tensor(
[state.start_time], dtype=torch.double, device="cuda"
)
torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
state.start_time = start_time_tensor.item()
print(
"time to initialize megatron (seconds): {:.3f}".format(
time.time() - state.start_time
)
)
torch.distributed.barrier()
# Context used for persisting some state between checkpoint saves.
checkpointing_context = init_checkpointing_context(megatron_cfg.checkpoint)
# Tokenizer
if megatron_cfg.tokenizer.hf_tokenizer_kwargs is None:
megatron_cfg.tokenizer.hf_tokenizer_kwargs = {}
megatron_cfg.tokenizer.hf_tokenizer_kwargs["trust_remote_code"] = True
megatron_cfg.tokenizer.hf_tokenizer_kwargs["use_fast"] = True
build_tokenizer(
megatron_cfg.tokenizer,
make_vocab_size_divisible_by=megatron_cfg.model.make_vocab_size_divisible_by
// megatron_cfg.model.tensor_model_parallel_size,
tensor_model_parallel_size=megatron_cfg.model.tensor_model_parallel_size,
)
assert megatron_cfg.model.vocab_size, "vocab size must be specified in model config"
torch.distributed.barrier()
pre_wrap_hook = []
use_peft = policy_cfg["megatron_cfg"].get("peft", {}).get("enabled", False)
draft_enabled = "draft" in policy_cfg and policy_cfg["draft"]["enabled"]
resume_checkpoint_exists = (
megatron_cfg.checkpoint.load is not None
and checkpoint_exists(megatron_cfg.checkpoint.load)
)
pretrained_checkpoint_exists = (
megatron_cfg.checkpoint.pretrained_checkpoint is not None
and checkpoint_exists(megatron_cfg.checkpoint.pretrained_checkpoint)
)
preload_policy_from_pretrained_for_draft = (
draft_enabled
and not use_peft # The PEFT pre-wrap hook loads the pretrained base policy before adapters are attached.
and not resume_checkpoint_exists # Resume checkpoints already carry the attached draft module state.
and pretrained_checkpoint_exists
)
mixed_precision_wrapper = Float16Module
if policy_cfg["megatron_cfg"]["freeze_moe_router"]:
def freeze_moe_router(megatron_model):
if not isinstance(megatron_model, list):
megatron_model = [megatron_model]
for model_module in megatron_model:
# Handle both wrapped (Float16Module) and unwrapped models
if isinstance(model_module, Float16Module):
model_module = model_module.module
# Handle VLM models
if hasattr(model_module, "language_model"):
model_module = model_module.language_model
for layer in model_module.decoder.layers:
if hasattr(layer, "mlp") and hasattr(layer.mlp, "router"):
layer.mlp.router.weight.requires_grad = False
mixed_precision_wrapper = MoEFloat16Module
pre_wrap_hook.extend([freeze_moe_router])
if use_peft:
peft_cfg = policy_cfg["megatron_cfg"].get("peft", {})
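        # Illustrative peft_cfg sketch (hypothetical values; the keys mirror the
        # LoRA(...) arguments consumed below):
        #   {"enabled": True, "dim": 8, "alpha": 16, "dropout": 0.0,
        #    "dropout_position": "pre", "target_modules": ["linear_qkv", "linear_proj"],
        #    "exclude_modules": [], "lora_A_init_method": "xavier",
        #    "lora_B_init_method": "zero", "a2a_experimental": False, "lora_dtype": None}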
if "dim" not in peft_cfg or peft_cfg["dim"] is None:
raise ValueError(
"If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
)
if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
raise ValueError(
"If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
)
peft = LoRA(
target_modules=peft_cfg["target_modules"],
exclude_modules=peft_cfg["exclude_modules"],
dim=peft_cfg["dim"],
alpha=peft_cfg["alpha"],
dropout=peft_cfg["dropout"],
dropout_position=peft_cfg["dropout_position"],
lora_A_init_method=peft_cfg["lora_A_init_method"],
lora_B_init_method=peft_cfg["lora_B_init_method"],
a2a_experimental=peft_cfg["a2a_experimental"],
lora_dtype=peft_cfg["lora_dtype"],
)
else:
peft = None
megatron_cfg.peft = peft
if megatron_cfg.peft is not None:
pre_peft_hook = _create_peft_pre_wrap_hook(megatron_cfg, state)
megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook)
def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
model = pre_peft_hook(model)
return model
pre_wrap_hook.extend([composed_peft_hook])
if draft_enabled:
draft_pre_wrap_hook = _create_draft_pre_wrap_hook(
policy_cfg,
megatron_cfg,
state,
preload_policy_from_pretrained=preload_policy_from_pretrained_for_draft,
)
pre_wrap_hook.extend([draft_pre_wrap_hook])
# Model, optimizer, and learning rate.
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
setattr(megatron_cfg.model, "_pg_collection", pg_collection)
if policy_cfg["megatron_cfg"].get("use_linear_ce_fusion_loss", False):
patch_gpt_model_forward_for_linear_ce_fusion(
chunk_size=policy_cfg["megatron_cfg"]["linear_ce_fusion_chunk_size"]
)
model = get_model(
megatron_cfg.model,
megatron_cfg.ddp,
use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2,
overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
pre_wrap_hook=pre_wrap_hook,
mixed_precision_wrapper=mixed_precision_wrapper,
pg_collection=pg_collection,
)
if load_optimizer:
optimizer, scheduler = setup_optimizer(
optimizer_config=megatron_cfg.optimizer,
scheduler_config=megatron_cfg.scheduler,
model=model,
use_gloo_process_groups=megatron_cfg.dist.use_gloo_process_groups,
)
else:
optimizer = None
scheduler = None
print("Model, optimizer, and learning rate scheduler built")
torch.distributed.barrier()
if megatron_cfg.peft is not None:
should_load_checkpoint = resume_checkpoint_exists
if should_load_checkpoint:
# The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states
# This is switched off here in order to load these states from the checkpoint
megatron_cfg.checkpoint.finetune = False
else:
should_load_checkpoint = resume_checkpoint_exists or (
pretrained_checkpoint_exists
and not preload_policy_from_pretrained_for_draft
)
# Load checkpoint if applicable
if should_load_checkpoint:
load_checkpoint(
state,
model,
optimizer,
scheduler,
checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
)
print("Checkpoint loaded")
torch.distributed.barrier()
draft_model = get_attached_draft_model(model)
# Set the param sync function for the model
param_sync_func = None
if megatron_cfg.ddp.overlap_param_gather and megatron_cfg.ddp.align_param_gather:
param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
if len(model) == 1:
param_sync_func = param_sync_func[0]
# Get the first model from the list
model = model[0]
return ModelAndOptimizerState(
state,
model,
optimizer,
scheduler,
checkpointing_context,
param_sync_func,
draft_model=draft_model,
)
def handle_model_import(
config: PolicyConfig,
hf_model_name: str,
pretrained_path: str,
pt_checkpoint_exists: bool,
) -> None:
"""Handle HF model import if checkpoint doesn't exist."""
force_reconvert_from_hf = config["megatron_cfg"].get(
"force_reconvert_from_hf", False
)
if pt_checkpoint_exists and not force_reconvert_from_hf:
print(f"Checkpoint already exists at {pretrained_path}. Skipping import.")
else:
hf_config_overrides = config.get("hf_config_overrides", {}) or {}
import_model_from_hf_name(
hf_model_name,
pretrained_path,
config["megatron_cfg"],
**hf_config_overrides,
)
if parallel_state.model_parallel_is_initialized():
print("Reinitializing model parallel after loading model state.")
parallel_state.destroy_model_parallel()
def setup_reference_model_state(
config: PolicyConfig, megatron_cfg: ConfigContainer, pretrained_path: str
) -> dict:
"""Setup the reference model for inference and return its state dict."""
# Create reference checkpoint config
ref_checkpoint_config = CheckpointConfig(
pretrained_checkpoint=pretrained_path,
save=None,
load=None,
fully_parallel_load=True,
load_rng=False,
)
ref_ckpt_context = init_checkpointing_context(ref_checkpoint_config)
# Create a separate megatron config for the reference model
ref_megatron_cfg = ConfigContainer(
model=megatron_cfg.model,