@@ -145,9 +145,11 @@ class _SuperInitRecorder:
145145
146146 def __init__ (self ):
147147 self .captured_config = None
148+ self .captured_kwargs = None
148149
149- def __call__ (self , instance , config , ** _ ):
150+ def __call__ (self , instance , config , ** kwargs ):
150151 self .captured_config = config
152+ self .captured_kwargs = kwargs
151153
152154
153155class TestStep35DecoderLayerIsSliding :
@@ -166,6 +168,31 @@ def _build(
166168 sliding_setting = _UNSET ,
167169 offset_return = 0 ,
168170 pp_rank = 0 ,
171+ ):
172+ config , recorder = self ._build_with_recorder (
173+ layer_number = layer_number ,
174+ is_mtp_layer = is_mtp_layer ,
175+ add_layer_offset = add_layer_offset ,
176+ layer_types = layer_types ,
177+ attention_other_setting = attention_other_setting ,
178+ sliding_setting = sliding_setting ,
179+ offset_return = offset_return ,
180+ pp_rank = pp_rank ,
181+ )
182+ return config , recorder .captured_config
183+
184+ def _build_with_recorder (
185+ self ,
186+ * ,
187+ layer_number ,
188+ is_mtp_layer = False ,
189+ add_layer_offset = True ,
190+ layer_types = None ,
191+ attention_other_setting = True ,
192+ sliding_setting = _UNSET ,
193+ offset_return = 0 ,
194+ pp_rank = 0 ,
195+ name = _UNSET ,
169196 ):
170197 layer_types = (
171198 layer_types
@@ -190,17 +217,20 @@ def _build(
190217 return_value = offset_return ,
191218 ),
192219 ):
193- Step35DecoderLayer (
194- config = config ,
195- submodules = None ,
196- layer_number = layer_number ,
197- pg_collection = SimpleNamespace (pp = "dummy" ),
198- vp_stage = None ,
199- is_mtp_layer = is_mtp_layer ,
200- add_layer_offset = add_layer_offset ,
201- )
202-
203- return config , recorder .captured_config
220+ layer_kwargs = {
221+ "config" : config ,
222+ "submodules" : None ,
223+ "layer_number" : layer_number ,
224+ "pg_collection" : SimpleNamespace (pp = "dummy" ),
225+ "vp_stage" : None ,
226+ "is_mtp_layer" : is_mtp_layer ,
227+ "add_layer_offset" : add_layer_offset ,
228+ }
229+ if name is not _UNSET :
230+ layer_kwargs ["name" ] = name
231+ Step35DecoderLayer (** layer_kwargs )
232+
233+ return config , recorder
204234
205235 def test_full_attention_keeps_original_config (self ):
206236 original , captured = self ._build (layer_number = 1 ) # layer_idx=0 -> full_attention
@@ -232,6 +262,18 @@ def test_mtp_layer_uses_global_layer_index_after_main_decoder(self):
232262 )
233263 assert captured is original # full attention
234264
265+ def test_mtp_layer_forwards_name_to_transformer_layer (self ):
266+ """MCore's MTP builder passes ``name`` into the nested transformer layer."""
267+ name = "decoder.layers.0.mtp_model_layer"
268+ original , recorder = self ._build_with_recorder (
269+ layer_number = 1 ,
270+ is_mtp_layer = True ,
271+ name = name ,
272+ )
273+
274+ assert recorder .captured_config is original
275+ assert recorder .captured_kwargs ["name" ] == name
276+
235277 def test_pp_offset_applied_for_main_decoder (self ):
236278 """With ``add_layer_offset=True`` the resolved index is
237279 ``layer_number + get_transformer_layer_offset(...) - 1``; the test forces
0 commit comments