save_megatron_model deadlocks during HF-to-Megatron checkpoint conversion (fully_parallel_save) #2225

@nic-nvidia


Summary

save_megatron_model deadlocks during the one-time HF-to-Megatron weight conversion when using the Megatron backend with non-colocated vLLM generation. The root cause is FullyParallelSaveStrategyWrapper, which is activated by CheckpointConfig.fully_parallel_save defaulting to True.

The deadlock occurs because FullyParallelSaveStrategyWrapper.apply_saving_parallelization() calls all_gather_object on DP sub-groups derived from the full torch.distributed world. In non-colocated mode, this world includes vLLM inference workers that never enter save_megatron_model, so they never participate in the collective -- causing a permanent hang.

The fix is to pass fully_parallel_save=False from community_import.py when calling bridge.save_megatron_model(). Megatron-Bridge already exposes this parameter.
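The failure mode can be modeled without GPUs: a collective like all_gather_object only returns once every rank in its group has entered it, so it behaves like a barrier over the group. A minimal sketch, with threads standing in for ranks (illustrative only, not NeMo RL code; a timeout is added so the demo does not actually hang):

```python
import threading

# Model of the hang: the DP sub-group has 4 members (e.g. ranks [0, 4, 8, 12]),
# but one of them is a vLLM inference rank that never calls the collective.
GROUP_SIZE = 4
barrier = threading.Barrier(GROUP_SIZE)

results = []

def training_rank(rank):
    try:
        # Training ranks enter the "collective" and block until all group
        # members arrive. The timeout stands in for "hangs indefinitely".
        barrier.wait(timeout=1.0)
        results.append((rank, "returned"))
    except threading.BrokenBarrierError:
        results.append((rank, "stuck"))

# Only 3 of the 4 members ever enter; the vLLM rank is busy serving
# generation requests and never does.
threads = [threading.Thread(target=training_rank, args=(r,))
           for r in range(GROUP_SIZE - 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # every participating rank ends up "stuck"
```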

Call Stack

| Level | Call | Issue |
|---|---|---|
| worker_groups.py:455 | Sets WORLD_SIZE env var | Includes all GPUs (training + inference) |
| setup.py:190 | init_process_group("nccl") | Default PG locked to full world |
| community_import.py:83 | initialize_model_parallel(seed=0) | All sub-groups (TP/DP/CP) include vLLM ranks |
| community_import.py:110 | bridge.save_megatron_model() | Only training ranks call this |
| checkpointing.py:674 | dist_checkpointing.save() | fully_parallel_save=True (default) activates FullyParallelSaveStrategyWrapper |
| fully_parallel.py:121 | determine_main_replica_uniform_distribution() | DP sub-group includes vLLM ranks |
| exchange_utils.py:213 | all_gather_object(group=dp_subgroup) | DEADLOCK -- vLLM ranks never enter |
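As the call stack shows, the collective-issuing code path only exists when the flag is set. A hedged sketch of the strategy selection (class names mirror the call stack; the actual Megatron-Bridge wiring may differ):

```python
class SaveStrategy:
    """Stand-in for a plain dist-checkpointing save strategy: each rank
    writes its own shards, no cross-rank coordination at save time."""
    requires_dp_collectives = False

class FullyParallelSaveStrategyWrapper(SaveStrategy):
    """Stand-in for the wrapper that distributes shard-saving work across
    the DP group. On the real wrapper, apply_saving_parallelization()
    issues all_gather_object over the DP sub-group -- the collective
    that hangs when that sub-group contains non-participating vLLM ranks."""
    requires_dp_collectives = True

    def __init__(self, base, dp_group_ranks):
        self.base = base
        self.dp_group_ranks = dp_group_ranks

def choose_save_strategy(base, fully_parallel_save, dp_group_ranks):
    # fully_parallel_save=True (the CheckpointConfig default) takes the
    # wrapper path; passing False keeps the collective-free base strategy.
    if fully_parallel_save:
        return FullyParallelSaveStrategyWrapper(base, dp_group_ranks)
    return base
```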

Reproduction

- Model: Nemotron 3 Super 120B with Megatron backend (TP=4, CP=4, ETP=4, EMP=4)
- Config: Non-colocated vLLM generation (colocated.enabled: false)
- Cluster: Multi-node (tested on 3x8 and 6x8 DGX B200, KubeRay 1.6.0)
- Container: nvcr.io/nvidia/nemo-rl:v0.5.0.nemotron_3_super
- Key condition: No cached Megatron checkpoint (first run or cache cleared)

On subsequent runs, the cached checkpoint is found and save_megatron_model is skipped entirely -- which is likely why this has not been caught previously.

Evidence

Using a collective tracing tool that wraps all torch.distributed collectives:

```
[COLL] rank=8 ENTER all_gather_object | size=4 ranks=[0, 4, 8, 12]
       | exchange_utils.py:213 <- fully_parallel.py:121 <- serialization.py:429
[CRUMB] stuck at distributed_c10d.py:3180 all_gather_object (indefinitely)
```

All training ranks enter the collective but none exit. Confirmed across 12 test runs.
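The core mechanism of such a tracer can be sketched as a wrapper that logs ENTER before a collective blocks and EXIT after it returns; a rank that logs ENTER with no matching EXIT is stuck inside the call. This is a sketch, not the actual tool: the `group_ranks` keyword and the stub collective are illustrative (the real tool would resolve member ranks from the ProcessGroup argument and patch e.g. torch.distributed.all_gather_object).

```python
import functools

log_lines = []

def trace_collective(fn, rank):
    """Wrap a collective-style callable to log ENTER/EXIT per rank."""
    @functools.wraps(fn)
    def wrapped(*args, group_ranks=None, **kwargs):
        line = f"[COLL] rank={rank} ENTER {fn.__name__} | ranks={group_ranks}"
        log_lines.append(line)
        print(line)
        result = fn(*args, **kwargs)  # blocks here if the group never completes
        line = f"[COLL] rank={rank} EXIT  {fn.__name__}"
        log_lines.append(line)
        print(line)
        return result
    return wrapped

# Demo on a stub collective that completes immediately:
def all_gather_object(object_list, obj):
    object_list.append(obj)

traced = trace_collective(all_gather_object, rank=8)
gathered = []
traced(gathered, "shard-metadata", group_ranks=[0, 4, 8, 12])
```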

Fix

One-line change in community_import.py:

```python
# Before:
bridge.save_megatron_model(megatron_model, output_path)

# After:
bridge.save_megatron_model(megatron_model, output_path, fully_parallel_save=False)
```

fully_parallel_save is already exposed as a parameter in Megatron-Bridge's save_megatron_model and AutoBridge.save_megatron_model. NeMo RL just needs to pass it.

Validated: 120B checkpoint saved successfully (242GB, 34 .distcp shards) with this fix on a 3x8 B200 cluster.
