
Draft: compatible with Megatron-FSDP TP #2299

Open
conver334 wants to merge 22 commits into NVIDIA-NeMo:main from conver334:new_fsdp

Conversation


@conver334 conver334 commented Feb 10, 2026

What does this PR do ?

In #1910, converting models between HuggingFace and Megatron-FSDP formats is not supported when TP (Tensor Parallel) is enabled.

The issue was caused by incorrect detection of TP mode for MCore model parameters in M-FSDP under certain conditions.

PR 3161, PR 3191, and PR 3287 will fix this problem. After those updates land, the interface gather_uneven_dtensor_to_full_tensor will be renamed to uneven_dtensor_to_full_tensor, and its return type will change from DTensor to Tensor. This PR updates the interface usage accordingly to stay compatible with those changes.

Changelog

  • Add specific, line-by-line info of the high-level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to

cc: @ISEEKYAN @shjwudp

Summary by CodeRabbit

Release Notes

  • New Features

    • Added comprehensive examples for converting Hugging Face models to Megatron FSDP format with round-trip validation and text generation capabilities
    • Enhanced weight conversion with improved distributed tensor handling and automatic model wrapper detection
  • Tests

    • Added functional test suite for validating Hugging Face to Megatron FSDP conversions

conver334 and others added 22 commits January 19, 2026 03:24
Signed-off-by: conver334 <conver334@gmail.com>
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>

copy-pr-bot bot commented Feb 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@conver334 conver334 changed the title from "Draft: prepare for new version of Megatron-FSDP" to "Draft: compatible with Megatron-FSDP TP" on Feb 10, 2026

coderabbitai bot commented Feb 10, 2026

📝 Walkthrough

Walkthrough

Introduces Megatron FSDP (Fully Sharded Data Parallel) integration with Hugging Face models through conversion scripts and supporting infrastructure. Adds round-trip conversion workflows, text generation capabilities, and DTensor-aware weight handling to bridge HF model formats with distributed Megatron FSDP models.

Changes

Cohort / File(s) Summary
Conversion Examples
examples/conversion/hf_fsdp_roundtrip.py, examples/conversion/hf_to_megtron_fsdp_generate_text.py
Two new scripts implementing HF-to-Megatron FSDP roundtrip workflows. First validates weight equivalence during conversion; second performs text generation with rank-aware distribution and token emission. Both include multi-precision parallelism configuration (TP, CP, EP), synchronization utilities, and CLI interfaces for model selection and parallelism knobs.
Core Conversion Logic
src/megatron/bridge/models/conversion/model_bridge.py
Adds public unwrap_model() function to handle FSDP-wrapped model unwrapping. Integrates FSDP detection and DTensor-aware weight handling in load_weights_hf_to_megatron, stream_weights_hf_to_megatron, and stream_weights_megatron_to_hf functions. Enables conditional unwrapping, weight slicing for FSDP shards, and re-wrapping post-conversion.
Parameter Mapping Logic
src/megatron/bridge/models/conversion/param_mapping.py
Adds DTensor detection via _module_uses_fsdp() helper. Updates scatter/gather pathways across multiple mapping classes (ColumnParallel, RowParallel, GatedMLP, QKV, etc.) to compute output_shape correctly for DTensor targets. Extends megatron_to_hf conditional logic to treat FSDP modules analogously to tp_size==1 scenarios.
Test Suite
tests/functional_tests/converter/test_hf_fsdp_conversion.py
New functional test validating HF-to-Megatron FSDP roundtrip. Includes deepseek_toy_model_path fixture for lightweight in-tree model setup and test_hf_fsdp_roundtrip parametrized test that spawns distributed runs, verifies return codes, and asserts presence of converted model artifacts (config.json, weight files).

Sequence Diagram

sequenceDiagram
    participant HF as HuggingFace Model
    participant Bridge as HF-Megatron Bridge
    participant Provider as Megatron Model Provider
    participant FSDP as FSDP Model
    participant Gen as Generation Loop
    participant Tokenizer as Tokenizer
    
    HF->>Bridge: Load via AutoBridge.from_hf_pretrained()
    Bridge->>Provider: Create model provider & configure TP/CP/EP
    Provider->>FSDP: Initialize distributed Megatron model
    Bridge->>FSDP: Load HF weights into distributed model
    FSDP->>FSDP: Move to CUDA & set eval mode
    HF->>Tokenizer: Load tokenizer from HF model
    
    loop For each generation step
        Gen->>FSDP: Forward pass via get_forward_backward_func()
        FSDP->>Gen: Return logits (last pipeline stage)
        Gen->>Gen: All-gather outputs across TP world
        Gen->>Gen: Select next token via argmax
        Gen->>Gen: Broadcast token IDs to all ranks
        Gen->>Gen: Update input_ids & position_ids
        Gen->>Tokenizer: Check if EOS token generated
    end
    
    FSDP-->>HF: Save converted model to HF format
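The per-step logic in the diagram (forward pass, argmax token selection, EOS check) can be sketched as a toy single-process loop. The TP all-gather and token broadcast are elided, and next_logits is a stand-in for the distributed forward pass, so this is illustrative only:

```python
def greedy_generate(next_logits, input_ids, eos_id, max_new_tokens):
    """Toy sketch of the generation loop in the diagram above.

    next_logits(ids) stands in for the pipeline forward pass and TP
    all-gather; the real script also broadcasts the chosen token to all ranks.
    """
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = next_logits(ids)  # forward-pass stand-in
        nxt = max(range(len(logits)), key=logits.__getitem__)  # argmax token
        ids.append(nxt)
        if nxt == eos_id:  # stop once EOS is emitted
            break
    return ids
```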

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

Run CICD

Suggested reviewers

  • yaoyu-33
🚥 Pre-merge checks: ✅ 1 | ❌ 3
❌ Failed checks (2 warnings, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 55.17%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Test Results For Major Changes — ⚠️ Warning: the PR introduces major FSDP support, DTensor handling, and weight-conversion logic changes, but the PR description lacks documentation of test execution, validation results, or numerical equivalence checks required for such critical modifications. Resolution: document test execution results, numerical validation evidence, regression testing outcomes, and resolution of critical review comments before marking the PR ready for review.
  • Title Check — ❓ Inconclusive: the title is vague and overly broad, using the non-descriptive phrasing "compatible with" without clearly summarizing the main changes. Resolution: revise the title to be more specific, e.g. "Add Megatron-FSDP support for HF model conversion with TP" or "Implement HF-to-Megatron FSDP round-trip conversion with tensor parallelism."
✅ Passed checks (1 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/megatron/bridge/models/conversion/param_mapping.py (1)

862-870: ⚠️ Potential issue | 🟡 Minor

Add defensive comment clarifying that _module_uses_fsdp(None) is safe and explain the None case on non-owning PP ranks.

The FSDP bypass assumption is correct for the primary call path in stream_weights_megatron_to_hf (model_bridge.py lines 1007-1012), which ensures DTensor materialization before calling megatron_to_hf. The adapter weight paths (peft_bridge.py) use regular tensor weights from adapters, not DTensors, so the FSDP bypass doesn't apply there.

However, _module_uses_fsdp(megatron_module) at line 865 lacks documentation: while hasattr(None, "_parameters") safely returns False, making the function handle None correctly, this behavior should be documented with a comment explaining that on non-owning PP ranks (where megatron_module is None), the bypass safely falls through to gather logic. This clarifies the contract for future readers.

🤖 Fix all issues with AI agents
In `@examples/conversion/hf_fsdp_roundtrip.py`:
- Around line 178-180: Remove the redundant explicit process-group teardown:
delete the torch.distributed.is_initialized() check and the call to
torch.distributed.destroy_process_group() at the end of main(), because main()
is already decorated with `@torchrun_main` which handles process group
destruction; locate the cleanup using the symbol
torch.distributed.destroy_process_group (and main()) and remove those two lines
so the process group is only torn down once.

In `@examples/conversion/hf_to_megtron_fsdp_generate_text.py`:
- Around line 252-254: The explicit teardown calling
torch.distributed.destroy_process_group() after checking
torch.distributed.is_initialized() duplicates cleanup performed by the
`@torchrun_main` decorator and can cause double-destroy errors; remove the
explicit torch.distributed.destroy_process_group() call (or guard it behind
logic that detects whether `@torchrun_main` managed the group) so only one
teardown path runs, ensuring the code relies on `@torchrun_main` to manage process
group lifecycle instead of calling destroy_process_group() unconditionally.
- Line 165: The example prompt string assigned to the variable prompt contains a
typo ("reforcement"); update the prompt in hf_to_megtron_fsdp_generate_text.py
by changing the value of prompt from "what is reforcement learning?" to "what is
reinforcement learning?" so the example query uses the correct spelling.
- Line 1: Rename the script file named "hf_to_megtron_fsdp_generate_text.py" to
"hf_to_megatron_fsdp_generate_text.py" (insert the missing "a" in "megatron")
and update any references to that filename across the repo (examples index,
README, CI job, import or invocation sites) so they point to the new correct
name; verify the module is importable under the new name and adjust any script
entrypoints or tests that reference hf_to_megtron_fsdp_generate_text.py.

In `@src/megatron/bridge/models/conversion/model_bridge.py`:
- Around line 815-818: The DTensor branch in the weight conversion loop skips
the subsequent shape-compatibility checks because of the early continue; update
the DTensor handling (where task.param_weight is a DTensor and you use
task.param_weight.megatron_fsdp_slice to index converted_weights and copy into
task.param_weight._local_tensor) to first validate that
sliced_converted_weights.shape matches
task.param_weight._local_tensor.reshape(-1).shape (or expected local shard
shape) and raise/assert with a clear message on mismatch before performing the
copy_. Preserve the existing reshape and copy_ logic but add this defensive
assertion/check immediately after computing sliced_converted_weights and before
calling copy_, removing the silent-risk continue behavior.
- Around line 985-998: The code only sets unwrapped_model_list inside the if
use_megatron_fsdp branch, causing NameError when use_megatron_fsdp is False;
ensure unwrapped_model_list is defined for both branches by setting it to the
original megatron_model when not using FSDP (i.e., after computing
use_megatron_fsdp, if True call unwrap_model(megatron_model) else assign
unwrapped_model_list = megatron_model), then continue to call
build_conversion_tasks and build_adapter_conversion_tasks using
unwrapped_model_list and set unwrapped_model = unwrapped_model_list[0]; update
references to use unwrapped_model_list/unwrapped_model accordingly.
- Around line 848-851: The loop invoking install_optimized_model_weights on
original_megatron_model when use_megatron_fsdp is true should defensively check
for the method before calling it; update the block that iterates over
original_megatron_model (inside the use_megatron_fsdp branch) to either use
hasattr(m.module, "install_optimized_model_weights") before calling or wrap the
call in a try/except AttributeError so models without
install_optimized_model_weights are skipped and do not raise; keep the rest of
the return behavior unchanged.

In `@src/megatron/bridge/models/conversion/param_mapping.py`:
- Around line 2076-2080: The torch.chunk call that splits megatron_weights
(torch.chunk(megatron_weights, self.tp_size, dim=0)) can produce unequal shards
if megatron_weights.size(0) is not divisible by self.tp_size; add a defensive
check before that line (inside the same method where _module_uses_fsdp is used
and gather_from_tp_ranks is called) to assert megatron_weights.size(0) %
self.tp_size == 0 and raise/abort with a clear message including
megatron_weights.size(0) and self.tp_size so the conversion fails loudly instead
of producing unequal shards that corrupt GatedMLPMapping gate/split logic.

In `@tests/functional_tests/converter/test_hf_fsdp_conversion.py`:
- Around line 98-99: The test contains a duplicated call to
model.save_pretrained(model_dir, safe_serialization=True); remove the redundant
second invocation so the model is only saved once. Locate the duplicate calls to
model.save_pretrained in test_hf_fsdp_conversion.py (within the test function
where the model is prepared and saved) and delete the extra line, leaving a
single model.save_pretrained(model_dir, safe_serialization=True) call.
- Around line 145-154: Replace the use of "assert False" with pytest.fail and
add a timeout to the subprocess.run call to avoid hangs; specifically, when
running the conversion subprocess (the cmd variable passed into subprocess.run)
add a timeout argument (e.g., timeout=XXX) and on non-zero return use
pytest.fail(...) instead of assert False, including result.returncode,
result.stdout and result.stderr in the failure message so the test prints useful
diagnostics (locate the subprocess.run call and the subsequent if
result.returncode != 0 block referencing result and cmd).
🧹 Nitpick comments (6)
src/megatron/bridge/models/conversion/model_bridge.py (1)

176-200: unwrap_model lacks type hints and has a minimal docstring.

Per coding guidelines, functions should use type hints for arguments and return types, and use Google-style docstrings.

Proposed improvement
-def unwrap_model(model, module_instances=None):
-    """Unwrap_model to return the final model instance"""
+def unwrap_model(
+    model: nn.Module | list[nn.Module],
+    module_instances: tuple[type, ...] | None = None,
+) -> nn.Module | list[nn.Module]:
+    """Unwrap a model by stripping DDP / FSDP / Float16Module wrappers.
+
+    Args:
+        model: A single module or list of modules to unwrap.
+        module_instances: Tuple of wrapper types to strip. If ``None``,
+            a default set of known Megatron wrappers is used.
+
+    Returns:
+        The innermost module(s) after removing all wrapper layers.
+    """

As per coding guidelines: "Use type hints for function arguments and return types" and "Use Google style docstrings."

examples/conversion/hf_fsdp_roundtrip.py (2)

117-117: Triple-quoted strings used as inline comments within a function body.

Lines 117 and 145 use triple-quoted strings (which are evaluated as no-op expressions) instead of # comments. This is unconventional and may confuse linters or readers.

Proposed fix
-    """Export Megatron-FSDP model to HuggingFace format and verify the weights"""
+    # Export Megatron-FSDP model to HuggingFace format and verify the weights
     table = Table(title="Hugging Face Weights Verification")
...
-    """Save Megatron-FSDP model to HuggingFace format"""
+    # Save Megatron-FSDP model to HuggingFace format
     if _is_rank_zero():

As per coding guidelines: "Reserve comments for code within a function or interfaces that are local to a file."

Also applies to: 145-145


16-24: Usage example in docstring uses python -m torch.distributed.run instead of uv run.

As per coding guidelines for examples/**/*.py: "Use 'uv run' to execute scripts instead of activating a virtual environment and calling 'python' directly."

tests/functional_tests/converter/test_hf_fsdp_conversion.py (2)

127-132: Use sys.executable instead of "python" for subprocess invocation.

Hardcoding "python" may pick up a different interpreter than the test is running under, leading to import errors or environment mismatches.

Proposed fix
         cmd = [
-            "python",
+            sys.executable,
             "-m",
             "torch.distributed.run",

91-95: Silent exception swallowing when saving tokenizer.

The bare except Exception: pass hides real errors (e.g., disk full, permission denied). At minimum, log the failure.

Proposed fix
         try:
             tokenizer = AutoTokenizer.from_pretrained("gpt2")
             tokenizer.save_pretrained(model_dir)
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"WARNING: Could not save tokenizer: {e}")
examples/conversion/hf_to_megtron_fsdp_generate_text.py (1)

83-116: Duplicated utility functions across example scripts.

_is_rank_zero, _maybe_barrier, _get_world_size, and _configure_model_provider are identical between this file and hf_fsdp_roundtrip.py. Consider extracting them into a shared module under examples/conversion/ or megatron.bridge.utils.

Comment on lines +178 to +180

if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, examine the file structure around main() and the destroy_process_group call
cat -n examples/conversion/hf_fsdp_roundtrip.py | head -200 | tail -50

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2338


🏁 Script executed:

# Search for the main function and `@torchrun_main` decorator
rg -n "@torchrun_main|def main" examples/conversion/hf_fsdp_roundtrip.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 101


🏁 Script executed:

# Find where `@torchrun_main` is defined
rg -n "def torchrun_main|@torchrun_main" --type=py -B2 -A10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 15504


🏁 Script executed:

# Read the complete torchrun_main implementation
cat -n src/megatron/bridge/models/decorators/torchrun.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2365


🏁 Script executed:

# Verify by checking if there are any other destroy_process_group calls in the file
rg -n "destroy_process_group" examples/conversion/hf_fsdp_roundtrip.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 124


Remove redundant process group cleanup at lines 179–180.

main() is decorated with @torchrun_main, which automatically destroys the process group on successful completion (see src/megatron/bridge/models/decorators/torchrun.py:38–39). The explicit torch.distributed.destroy_process_group() call here is redundant and will cause a second destruction attempt, resulting in an error or warning.

Proposed fix
     main(
         hf_model_id=args.hf_model_id,
         output_dir=args.output_dir,
         tp=args.tp,
         cp=args.cp,
         ep=args.ep,
         trust_remote_code=args.trust_remote_code,
     )
-
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()

@@ -0,0 +1,254 @@
#!/usr/bin/env python3

⚠️ Potential issue | 🟡 Minor

Filename typo: hf_to_megtron_fsdp_generate_text.py → hf_to_megatron_fsdp_generate_text.py.

Missing "a" in "megatron". This will make the script harder to discover and is inconsistent with naming elsewhere.


Comment on lines +252 to +254

if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
⚠️ Potential issue | 🟡 Minor

Potential double destruction of the process group.

Same issue as in hf_fsdp_roundtrip.py: @torchrun_main likely already destroys the process group, so the explicit teardown at line 253–254 may cause errors.


Comment on lines +815 to +818
if isinstance(task.param_weight, DTensor):
sliced_converted_weights = converted_weights.reshape(-1)[task.param_weight.megatron_fsdp_slice]
task.param_weight._local_tensor.reshape(-1).copy_(sliced_converted_weights)
continue
⚠️ Potential issue | 🟠 Major

DTensor weight copy skips shape validation.

The continue on line 818 bypasses the shape-compatibility check (lines 820–845) that exists for regular tensors. A shape mismatch between converted_weights and the DTensor's expected local shard would silently corrupt the model. Consider adding a defensive assertion before the copy.

Proposed fix
                 if isinstance(task.param_weight, DTensor):
+                    expected_size = task.param_weight._local_tensor.reshape(-1).numel()
                     sliced_converted_weights = converted_weights.reshape(-1)[task.param_weight.megatron_fsdp_slice]
+                    if sliced_converted_weights.numel() != expected_size:
+                        raise ValueError(
+                            f"DTensor shape mismatch for {task.mapping.megatron_param}: "
+                            f"expected {expected_size} elements, got {sliced_converted_weights.numel()}"
+                        )
                     task.param_weight._local_tensor.reshape(-1).copy_(sliced_converted_weights)
                     continue

Based on learnings: "Always validate tensor shapes before copying weights in weight conversion".


Comment on lines +848 to +851
if use_megatron_fsdp:
for m in original_megatron_model:
m.module.install_optimized_model_weights()
return original_megatron_model
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find where install_optimized_model_weights is defined
rg -n "def install_optimized_model_weights" --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

#!/bin/bash
# Check the context of lines 848-851 in model_bridge.py
cat -n src/megatron/bridge/models/conversion/model_bridge.py | sed -n '840,860p'

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1212


🏁 Script executed:

#!/bin/bash
# Find what types of objects are in original_megatron_model
rg -B 5 -A 5 "original_megatron_model" src/megatron/bridge/models/conversion/model_bridge.py | head -50

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1101


🏁 Script executed:

#!/bin/bash
# Search for any hasattr or try-except guards around line 848-851
rg -B 10 -A 10 "install_optimized_model_weights" src/megatron/bridge/models/conversion/model_bridge.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1065


🏁 Script executed:

#!/bin/bash
# Search for similar method names that might be alternatives
rg -n "optimized.*weight|install.*weight" --type=py | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 187


🏁 Script executed:

#!/bin/bash
# Check what FullyShardedDataParallel wraps and what unwrap_model returns
rg -n "def unwrap_model|class FullyShardedDataParallel" --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 176


🏁 Script executed:

#!/bin/bash
# Search for any hasattr checks related to model methods
rg -n "hasattr.*module.*install|hasattr.*optimized" --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

#!/bin/bash
# Look at the unwrap_model function to understand what it returns
cat -n src/megatron/bridge/models/conversion/model_bridge.py | sed -n '176,210p'

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1895


🏁 Script executed:

#!/bin/bash
# Search for any model class definitions that might have this method
rg -n "class.*Model.*:" src/megatron/bridge/models/ --type=py | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2109


🏁 Script executed:

#!/bin/bash
# Check if install_optimized_model_weights might be expected to be imported
rg -n "import.*install_optimized_model_weights|from.*install_optimized_model_weights" --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

#!/bin/bash
# Search for any recent additions or TODOs related to this
rg -n "TODO|FIXME|XXX.*optim|XXX.*weight" src/megatron/bridge/models/conversion/model_bridge.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 224


🏁 Script executed:

#!/bin/bash
# Check if install_optimized_model_weights might be defined in megatron.core
find . -type f -name "*.py" -exec grep -l "def install_optimized_model_weights" {} \;

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

#!/bin/bash
# Check the broader context of the FSDP handling - see how models are created
rg -B 20 "use_megatron_fsdp = isinstance" src/megatron/bridge/models/conversion/model_bridge.py | head -40

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1664


🏁 Script executed:

#!/bin/bash
# Check if there's a pattern of hasattr checks elsewhere in the file for optional methods
rg -n "hasattr.*module" src/megatron/bridge/models/conversion/model_bridge.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 228


🏁 Script executed:

#!/bin/bash
# Check if this method might be expected to be in megatron.core or another external library
rg -n "install_optimized_model_weights" . --type=py 2>/dev/null | wc -l

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 72


🏁 Script executed:

#!/bin/bash
# Look at what type m.module would be - check base model classes
rg -n "class.*GPT.*Model|class.*T5.*Model" src/megatron/bridge/models/ --type=py | head -10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1121


🏁 Script executed:

#!/bin/bash
# Check git history to see if this is newly added (last few commits)
git log -10 --oneline src/megatron/bridge/models/conversion/model_bridge.py 2>/dev/null || echo "Git history not available"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 110


🏁 Script executed:

#!/bin/bash
# Check the entire git diff or blame for this line
git log -p -S "install_optimized_model_weights" -- src/megatron/bridge/models/conversion/model_bridge.py 2>/dev/null | head -50 || echo "Git details not available"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

#!/bin/bash
# Check if there's any pattern of methods that models are expected to have
rg -n "m\.module\." src/megatron/bridge/models/conversion/model_bridge.py | head -10


Length of output: 133


🏁 Script executed:

#!/bin/bash
# Look at the specific commit to understand context
git show 4ac2961 --stat 2>/dev/null | head -20 || echo "Commit details not available"


Length of output: 1022


🏁 Script executed:

#!/bin/bash
# Check if there's a type annotation or comment explaining what m.module should be
cat -n src/megatron/bridge/models/conversion/model_bridge.py | sed -n '845,860p'


Length of output: 875


Add a defensive check before calling install_optimized_model_weights().

The method install_optimized_model_weights() is not defined anywhere in the codebase and is being called unconditionally at line 850. This will raise AttributeError if a model variant doesn't implement this method. Add a hasattr check (consistent with the pattern used elsewhere in this file, e.g., line 794) or wrap in a try-except block to handle models that don't have this method.

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/conversion/model_bridge.py` around lines 848 -
851, The loop invoking install_optimized_model_weights on
original_megatron_model when use_megatron_fsdp is true should defensively check
for the method before calling it; update the block that iterates over
original_megatron_model (inside the use_megatron_fsdp branch) to either use
hasattr(m.module, "install_optimized_model_weights") before calling or wrap the
call in a try/except AttributeError so models without
install_optimized_model_weights are skipped and do not raise; keep the rest of
the return behavior unchanged.

Comment on lines +2076 to +2080
if _module_uses_fsdp(megatron_module):
gathered_shards = torch.chunk(megatron_weights, self.tp_size, dim=0)
else:
# Gather shards from all TP ranks
gathered_shards = self.gather_from_tp_ranks(megatron_weights)

⚠️ Potential issue | 🟡 Minor

GatedMLPMapping: torch.chunk(megatron_weights, self.tp_size, dim=0) relies on dim 0 of the full tensor dividing evenly into tp_size chunks.

If the full FSDP-materialized tensor's dim 0 is not evenly divisible by tp_size, torch.chunk will produce unequal splits, which would silently corrupt the gate/up splitting downstream. Consider adding a divisibility assertion.

Proposed defensive check
             if _module_uses_fsdp(megatron_module):
+                assert megatron_weights.shape[0] % self.tp_size == 0, (
+                    f"FSDP full tensor dim 0 ({megatron_weights.shape[0]}) must be divisible by tp_size ({self.tp_size})"
+                )
                 gathered_shards = torch.chunk(megatron_weights, self.tp_size, dim=0)

Based on learnings: "Always validate tensor shapes before copying weights in weight conversion".

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/conversion/param_mapping.py` around lines 2076 -
2080, The torch.chunk call that splits megatron_weights
(torch.chunk(megatron_weights, self.tp_size, dim=0)) can produce unequal shards
if megatron_weights.size(0) is not divisible by self.tp_size; add a defensive
check before that line (inside the same method where _module_uses_fsdp is used
and gather_from_tp_ranks is called) to assert megatron_weights.size(0) %
self.tp_size == 0 and raise/abort with a clear message including
megatron_weights.size(0) and self.tp_size so the conversion fails loudly instead
of producing unequal shards that corrupt GatedMLPMapping gate/split logic.

Comment on lines +98 to +99
model.save_pretrained(model_dir, safe_serialization=True)
model.save_pretrained(model_dir, safe_serialization=True)

⚠️ Potential issue | 🟡 Minor

Duplicate model.save_pretrained call.

model.save_pretrained(model_dir, safe_serialization=True) is called twice in succession. This appears to be a copy-paste error — the second call is redundant.

Proposed fix
         model.save_pretrained(model_dir, safe_serialization=True)
-        model.save_pretrained(model_dir, safe_serialization=True)
🤖 Prompt for AI Agents
In `@tests/functional_tests/converter/test_hf_fsdp_conversion.py` around lines 98
- 99, The test contains a duplicated call to model.save_pretrained(model_dir,
safe_serialization=True); remove the redundant second invocation so the model is
only saved once. Locate the duplicate calls to model.save_pretrained in
test_hf_fsdp_conversion.py (within the test function where the model is prepared
and saved) and delete the extra line, leaving a single
model.save_pretrained(model_dir, safe_serialization=True) call.

Comment on lines +145 to +154
try:
result = subprocess.run(
cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent
)

# Check that the conversion completed successfully
if result.returncode != 0:
print(f"STDOUT: {result.stdout}")
print(f"STDERR: {result.stderr}")
assert False, f"FSDP Roundtrip failed with return code {result.returncode}"

⚠️ Potential issue | 🟠 Major

Use pytest.fail() instead of assert False and add a timeout.

assert False is stripped under python -O (as flagged by static analysis B011). Also, the subprocess has no timeout, so a hung conversion will hang CI indefinitely.

Proposed fix
         try:
             result = subprocess.run(
-                cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent
+                cmd, capture_output=True, text=True, timeout=600,
+                cwd=Path(__file__).parent.parent.parent.parent,
             )

             # Check that the conversion completed successfully
             if result.returncode != 0:
                 print(f"STDOUT: {result.stdout}")
                 print(f"STDERR: {result.stderr}")
-                assert False, f"FSDP Roundtrip failed with return code {result.returncode}"
+                pytest.fail(f"FSDP Roundtrip failed with return code {result.returncode}")
🧰 Tools
🪛 Ruff (0.14.14)

[error] 146-146: subprocess call: check for execution of untrusted input

(S603)


[warning] 154-154: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)

🤖 Prompt for AI Agents
In `@tests/functional_tests/converter/test_hf_fsdp_conversion.py` around lines 145
- 154, Replace the use of "assert False" with pytest.fail and add a timeout to
the subprocess.run call to avoid hangs; specifically, when running the
conversion subprocess (the cmd variable passed into subprocess.run) add a
timeout argument (e.g., timeout=XXX) and on non-zero return use pytest.fail(...)
instead of assert False, including result.returncode, result.stdout and
result.stderr in the failure message so the test prints useful diagnostics
(locate the subprocess.run call and the subsequent if result.returncode != 0
block referencing result and cmd).
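The fix the prompt describes can be sketched as a small helper (an illustrative sketch, not the test's actual code; `run_conversion` is a hypothetical name, and in a pytest suite `pytest.fail(...)` would replace the `RuntimeError`):

```python
import subprocess


def run_conversion(cmd, timeout=600):
    """Run a conversion subprocess with a timeout and return its CompletedProcess.

    Raises RuntimeError (rather than a bare `assert False`, which `python -O`
    strips) when the process fails or exceeds the timeout, echoing stdout and
    stderr for diagnostics.
    """
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(f"Conversion timed out after {timeout}s: {cmd}") from exc
    if result.returncode != 0:
        raise RuntimeError(
            f"Conversion failed with return code {result.returncode}\n"
            f"STDOUT: {result.stdout}\nSTDERR: {result.stderr}"
        )
    return result
```

A hung conversion now surfaces as `subprocess.TimeoutExpired` after `timeout` seconds instead of stalling CI, and a non-zero exit fails loudly with the captured output attached.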
