Skip to content

Commit 2563be9

Browse files
yaoyu-33, NeMo Bot
authored and committed
[model] fix: pass Step35 layer names through MTP (#4368)
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
1 parent f6fc884 commit 2563be9

2 files changed

Lines changed: 56 additions & 12 deletions

File tree

src/megatron/bridge/models/stepfun/step35_provider.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,7 @@ def __init__(
8989
is_mtp_layer: bool = False,
9090
add_layer_offset: bool = True,
9191
pp_layer_offset: Optional[int] = None,
92+
name: str | None = None,
9293
):
9394
pp_rank = get_pg_rank(pg_collection.pp)
9495
if is_mtp_layer:
@@ -139,6 +140,7 @@ def __init__(
139140
is_mtp_layer=is_mtp_layer,
140141
add_layer_offset=add_layer_offset,
141142
pp_layer_offset=pp_layer_offset,
143+
name=name,
142144
)
143145

144146

tests/unit_tests/models/stepfun/test_step35_provider.py

Lines changed: 54 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -145,9 +145,11 @@ class _SuperInitRecorder:
145145

146146
def __init__(self):
147147
self.captured_config = None
148+
self.captured_kwargs = None
148149

149-
def __call__(self, instance, config, **_):
150+
def __call__(self, instance, config, **kwargs):
150151
self.captured_config = config
152+
self.captured_kwargs = kwargs
151153

152154

153155
class TestStep35DecoderLayerIsSliding:
@@ -166,6 +168,31 @@ def _build(
166168
sliding_setting=_UNSET,
167169
offset_return=0,
168170
pp_rank=0,
171+
):
172+
config, recorder = self._build_with_recorder(
173+
layer_number=layer_number,
174+
is_mtp_layer=is_mtp_layer,
175+
add_layer_offset=add_layer_offset,
176+
layer_types=layer_types,
177+
attention_other_setting=attention_other_setting,
178+
sliding_setting=sliding_setting,
179+
offset_return=offset_return,
180+
pp_rank=pp_rank,
181+
)
182+
return config, recorder.captured_config
183+
184+
def _build_with_recorder(
185+
self,
186+
*,
187+
layer_number,
188+
is_mtp_layer=False,
189+
add_layer_offset=True,
190+
layer_types=None,
191+
attention_other_setting=True,
192+
sliding_setting=_UNSET,
193+
offset_return=0,
194+
pp_rank=0,
195+
name=_UNSET,
169196
):
170197
layer_types = (
171198
layer_types
@@ -190,17 +217,20 @@ def _build(
190217
return_value=offset_return,
191218
),
192219
):
193-
Step35DecoderLayer(
194-
config=config,
195-
submodules=None,
196-
layer_number=layer_number,
197-
pg_collection=SimpleNamespace(pp="dummy"),
198-
vp_stage=None,
199-
is_mtp_layer=is_mtp_layer,
200-
add_layer_offset=add_layer_offset,
201-
)
202-
203-
return config, recorder.captured_config
220+
layer_kwargs = {
221+
"config": config,
222+
"submodules": None,
223+
"layer_number": layer_number,
224+
"pg_collection": SimpleNamespace(pp="dummy"),
225+
"vp_stage": None,
226+
"is_mtp_layer": is_mtp_layer,
227+
"add_layer_offset": add_layer_offset,
228+
}
229+
if name is not _UNSET:
230+
layer_kwargs["name"] = name
231+
Step35DecoderLayer(**layer_kwargs)
232+
233+
return config, recorder
204234

205235
def test_full_attention_keeps_original_config(self):
206236
original, captured = self._build(layer_number=1) # layer_idx=0 -> full_attention
@@ -232,6 +262,18 @@ def test_mtp_layer_uses_global_layer_index_after_main_decoder(self):
232262
)
233263
assert captured is original # full attention
234264

265+
def test_mtp_layer_forwards_name_to_transformer_layer(self):
266+
"""MCore's MTP builder passes ``name`` into the nested transformer layer."""
267+
name = "decoder.layers.0.mtp_model_layer"
268+
original, recorder = self._build_with_recorder(
269+
layer_number=1,
270+
is_mtp_layer=True,
271+
name=name,
272+
)
273+
274+
assert recorder.captured_config is original
275+
assert recorder.captured_kwargs["name"] == name
276+
235277
def test_pp_offset_applied_for_main_decoder(self):
236278
"""With ``add_layer_offset=True`` the resolved index is
237279
``layer_number + get_transformer_layer_offset(...) - 1``; the test forces

0 commit comments

Comments
 (0)