Skip to content

Commit 6035e7c

Browse files
yfw, claude[bot], and claude
authored
Allow mtp_num_layers to be None (#4216)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifuw@nvidia.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 8e002d8 commit 6035e7c

2 files changed

Lines changed: 50 additions & 3 deletions

File tree

src/megatron/bridge/models/mamba/mamba_provider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class MambaModelProvider(TransformerConfig, ModelProviderMixin[MCoreMambaModel])
137137
_pg_collection: Optional[ProcessGroupCollection] = None
138138

139139
# MTP
140-
mtp_num_layers: int = 0
140+
mtp_num_layers: int | None = 0
141141
mtp_hybrid_override_pattern: Optional[str] = None
142142
keep_mtp_spec_in_bf16: bool = False
143143

@@ -182,9 +182,9 @@ def finalize(self) -> None:
182182
# Include the pattern at least once so the MTP block (and its weights)
183183
# are created even when mtp_num_layers=0.
184184
if self.mtp_use_repeated_layer:
185-
num_pattern_copies = max(1, self.mtp_num_layers)
185+
num_pattern_copies = max(1, self.mtp_num_layers or 0)
186186
else:
187-
num_pattern_copies = self.mtp_num_layers
187+
num_pattern_copies = self.mtp_num_layers or 0
188188
self.hybrid_layer_pattern = (
189189
main_pattern + sep + sep.join([self.mtp_hybrid_override_pattern] * num_pattern_copies)
190190
)

tests/unit_tests/models/mamba/test_mamba_provider.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,50 @@ def test_finalize_uses_compatible_hybrid_layer_count(self):
315315

316316
assert provider.num_layers == 9
317317
mock_finalize.assert_called_once_with(provider)
318+
319+
def test_finalize_mtp_num_layers_none_with_repeated_layer(self):
320+
"""finalize must not crash when mtp_num_layers is None and mtp_use_repeated_layer is True.
321+
322+
With repeated layers the shared MTP block is always materialized at least once,
323+
so a None mtp_num_layers must be coerced to 0 before the max(1, ...) clamp.
324+
"""
325+
sep = mamba_provider.Symbols.MTP_SEPARATOR
326+
provider = MambaModelProvider(
327+
hidden_size=128,
328+
num_attention_heads=1,
329+
hybrid_layer_pattern="M-M-M-M-",
330+
mtp_hybrid_override_pattern="M*",
331+
mtp_num_layers=None,
332+
mtp_use_repeated_layer=True,
333+
)
334+
335+
with patch.object(mamba_provider.TransformerConfig, "finalize", autospec=True):
336+
provider.finalize()
337+
338+
# The shared MTP block is included exactly once (max(1, None or 0) == 1).
339+
assert provider.hybrid_layer_pattern == "M-M-M-M-" + sep + "M*"
340+
# mtp_num_layers is inferred from the constructed pattern rather than left as None.
341+
assert provider.mtp_num_layers is not None
342+
343+
def test_finalize_mtp_num_layers_none_without_repeated_layer(self):
344+
"""finalize must not crash when mtp_num_layers is None and mtp_use_repeated_layer is False.
345+
346+
Without repeated layers the copy count is mtp_num_layers directly; a None value must be
347+
coerced to 0 so the pattern construction (`[pattern] * count`) does not raise TypeError.
348+
"""
349+
sep = mamba_provider.Symbols.MTP_SEPARATOR
350+
provider = MambaModelProvider(
351+
hidden_size=128,
352+
num_attention_heads=1,
353+
hybrid_layer_pattern="M-M-M-M-",
354+
mtp_hybrid_override_pattern="M*",
355+
mtp_num_layers=None,
356+
mtp_use_repeated_layer=False,
357+
)
358+
359+
with patch.object(mamba_provider.TransformerConfig, "finalize", autospec=True):
360+
provider.finalize()
361+
362+
# Zero copies of the MTP block are appended (None or 0 == 0).
363+
assert provider.hybrid_layer_pattern == "M-M-M-M-" + sep
364+
assert provider.mtp_num_layers is None

0 commit comments

Comments
 (0)