Skip to content

Commit 6991dea

Browse files
committed
[test] test: add model contract matrix
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent 579f5c8 commit 6991dea

1 file changed

Lines changed: 269 additions & 0 deletions

File tree

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shared config-only AutoBridge/provider contracts for high-risk model families."""
16+
17+
from collections.abc import Callable, Mapping
18+
from dataclasses import dataclass
19+
from importlib import import_module
20+
from typing import Any
21+
22+
import pytest
23+
from transformers import PretrainedConfig
24+
25+
from megatron.bridge.models.conversion.auto_bridge import AutoBridge
26+
from megatron.bridge.models.qwen3_asr.hf_qwen3_asr.configuration_qwen3_asr import Qwen3ASRConfig
27+
from megatron.bridge.models.stepfun.configuration_step35 import Step35Config
28+
29+
30+
pytestmark = [pytest.mark.unit]
31+
32+
33+
@dataclass(frozen=True)
34+
class ModelProviderContractCase:
35+
"""Config-only bridge/provider contract for one HF architecture."""
36+
37+
name: str
38+
architecture: str
39+
config_factory: Callable[[], PretrainedConfig]
40+
bridge_symbol: str
41+
provider_symbol: str
42+
expected_provider_attrs: Mapping[str, Any]
43+
44+
45+
def _resolve_symbol(qualified_name: str) -> type:
46+
module_name, symbol_name = qualified_name.rsplit(".", 1)
47+
return getattr(import_module(module_name), symbol_name)
48+
49+
50+
def _make_qwen3_asr_config() -> Qwen3ASRConfig:
51+
return Qwen3ASRConfig(
52+
architectures=["Qwen3ASRForConditionalGeneration"],
53+
thinker_config={
54+
"torch_dtype": "bfloat16",
55+
"audio_config": {
56+
"encoder_layers": 2,
57+
},
58+
"text_config": {
59+
"hidden_size": 128,
60+
"intermediate_size": 256,
61+
"num_hidden_layers": 2,
62+
"num_attention_heads": 4,
63+
"num_key_value_heads": 2,
64+
"vocab_size": 512,
65+
"max_position_embeddings": 1024,
66+
"initializer_range": 0.02,
67+
"rms_norm_eps": 1e-6,
68+
"rope_theta": 5000000.0,
69+
"tie_word_embeddings": False,
70+
},
71+
},
72+
)
73+
74+
75+
def _make_step35_config() -> Step35Config:
76+
return Step35Config(
77+
hidden_size=128,
78+
intermediate_size=256,
79+
num_attention_heads=4,
80+
num_attention_groups=2,
81+
num_hidden_layers=4,
82+
vocab_size=512,
83+
max_position_embeddings=1024,
84+
moe_intermediate_size=64,
85+
moe_num_experts=4,
86+
moe_top_k=2,
87+
share_expert_dim=64,
88+
head_dim=32,
89+
layer_types=[
90+
"full_attention",
91+
"sliding_attention",
92+
"full_attention",
93+
"sliding_attention",
94+
"full_attention",
95+
"sliding_attention",
96+
],
97+
attention_other_setting={
98+
"attention_type": "sliding_attention",
99+
"num_attention_heads": 4,
100+
"num_attention_groups": 2,
101+
"head_dim": 32,
102+
},
103+
sliding_window=128,
104+
num_nextn_predict_layers=2,
105+
moe_layers_enum=(2, 3),
106+
torch_dtype="bfloat16",
107+
)
108+
109+
110+
def _make_mimo_v2_flash_config() -> PretrainedConfig:
111+
return PretrainedConfig(
112+
architectures=["MiMoV2FlashForCausalLM"],
113+
model_type="mimo_v2_flash",
114+
num_hidden_layers=6,
115+
hidden_size=256,
116+
intermediate_size=512,
117+
num_attention_heads=8,
118+
num_key_value_heads=2,
119+
head_dim=32,
120+
vocab_size=1024,
121+
max_position_embeddings=2048,
122+
rope_theta=5000000,
123+
rms_norm_eps=1e-5,
124+
initializer_range=0.02,
125+
tie_word_embeddings=False,
126+
attention_bias=False,
127+
mlp_bias=False,
128+
hidden_act="silu",
129+
layernorm_epsilon=1e-5,
130+
v_head_dim=16,
131+
hybrid_layer_pattern=[0, 1, 1, 1, 0, 1],
132+
sliding_window_size=128,
133+
sliding_window=128,
134+
attention_chunk_size=128,
135+
swa_rope_theta=10000,
136+
swa_num_key_value_heads=4,
137+
swa_num_attention_heads=8,
138+
swa_head_dim=32,
139+
swa_v_head_dim=16,
140+
add_swa_attention_sink_bias=True,
141+
add_full_attention_sink_bias=False,
142+
attention_value_scale=0.707,
143+
moe_layer_freq=[0, 1, 1, 1, 1, 1],
144+
n_routed_experts=8,
145+
moe_intermediate_size=128,
146+
num_experts_per_tok=2,
147+
scoring_func="sigmoid",
148+
n_shared_experts=None,
149+
n_group=1,
150+
topk_group=1,
151+
topk_method="noaux_tc",
152+
norm_topk_prob=True,
153+
routed_scaling_factor=None,
154+
torch_dtype="bfloat16",
155+
)
156+
157+
158+
def _make_nemotron_labs_diffusion_config() -> PretrainedConfig:
159+
text_config = PretrainedConfig(
160+
hidden_size=128,
161+
intermediate_size=256,
162+
num_hidden_layers=2,
163+
tie_word_embeddings=True,
164+
rope_parameters={"rope_theta": 10000.0},
165+
vocab_size=512,
166+
)
167+
return PretrainedConfig(
168+
architectures=["NemotronLabsDiffusionModel"],
169+
model_type="nemotron_labs_diffusion",
170+
text_config=text_config,
171+
)
172+
173+
174+
G_CONTRACT_CASES = (
175+
ModelProviderContractCase(
176+
name="qwen3_asr_nested_config",
177+
architecture="Qwen3ASRForConditionalGeneration",
178+
config_factory=_make_qwen3_asr_config,
179+
bridge_symbol="megatron.bridge.models.qwen3_asr.qwen3_asr_bridge.Qwen3ASRBridge",
180+
provider_symbol="megatron.bridge.models.qwen3_asr.qwen3_asr_provider.Qwen3ASRModelProvider",
181+
expected_provider_attrs={
182+
"hidden_size": 128,
183+
"num_layers": 2,
184+
"num_query_groups": 2,
185+
"vocab_size": 512,
186+
"audio_token_id": 151646,
187+
"share_embeddings_and_output_weights": False,
188+
},
189+
),
190+
ModelProviderContractCase(
191+
name="step35_mtp_layer_types",
192+
architecture="Step3p5ForCausalLM",
193+
config_factory=_make_step35_config,
194+
bridge_symbol="megatron.bridge.models.stepfun.step35_bridge.Step35Bridge",
195+
provider_symbol="megatron.bridge.models.stepfun.step35_provider.Step35ModelProvider",
196+
expected_provider_attrs={
197+
"hidden_size": 128,
198+
"num_layers": 4,
199+
"num_query_groups": 2,
200+
"num_moe_experts": 4,
201+
"moe_router_topk": 2,
202+
"moe_layer_freq": [0, 0, 1, 1],
203+
"layer_types": [
204+
"full_attention",
205+
"sliding_attention",
206+
"full_attention",
207+
"sliding_attention",
208+
"full_attention",
209+
"sliding_attention",
210+
],
211+
},
212+
),
213+
ModelProviderContractCase(
214+
name="mimo_v2_flash_registration",
215+
architecture="MiMoV2FlashForCausalLM",
216+
config_factory=_make_mimo_v2_flash_config,
217+
bridge_symbol="megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_bridge.MiMoV2FlashBridge",
218+
provider_symbol="megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_provider.MiMoV2FlashModelProvider",
219+
expected_provider_attrs={
220+
"hidden_size": 256,
221+
"num_layers": 6,
222+
"num_query_groups": 2,
223+
"full_attn_num_query_groups": 2,
224+
"swa_num_query_groups": 4,
225+
"v_head_dim": 16,
226+
"window_size": 128,
227+
"mtp_num_layers": 0,
228+
},
229+
),
230+
ModelProviderContractCase(
231+
name="nemotron_labs_diffusion_text_config",
232+
architecture="NemotronLabsDiffusionModel",
233+
config_factory=_make_nemotron_labs_diffusion_config,
234+
bridge_symbol=(
235+
"megatron.bridge.diffusion.conversion.nemotron_labs_diffusion."
236+
"nemotron_labs_diffusion_bridge.NemotronLabsDiffusionBridge"
237+
),
238+
provider_symbol=(
239+
"megatron.bridge.diffusion.models.nemotron_labs_diffusion."
240+
"nemotron_labs_diffusion_provider.NemotronLabsDiffusionModelProvider"
241+
),
242+
expected_provider_attrs={
243+
"hidden_size": 128,
244+
"ffn_hidden_size": 256,
245+
"num_layers": 2,
246+
"vocab_size": 512,
247+
"share_embeddings_and_output_weights": True,
248+
"rotary_base": 10000.0,
249+
},
250+
),
251+
)
252+
253+
254+
@pytest.mark.parametrize("case", G_CONTRACT_CASES, ids=[case.name for case in G_CONTRACT_CASES])
255+
def test_config_only_autobridge_provider_contract(case: ModelProviderContractCase) -> None:
256+
bridge_type = _resolve_symbol(case.bridge_symbol)
257+
provider_type = _resolve_symbol(case.provider_symbol)
258+
config = case.config_factory()
259+
260+
assert AutoBridge.supports(config) is True
261+
assert case.architecture in AutoBridge.list_supported_models()
262+
263+
bridge = AutoBridge.from_hf_config(config)
264+
assert isinstance(bridge._model_bridge, bridge_type)
265+
266+
provider = bridge.to_megatron_provider(load_weights=False)
267+
assert isinstance(provider, provider_type)
268+
for attr_name, expected_value in case.expected_provider_attrs.items():
269+
assert getattr(provider, attr_name) == expected_value

0 commit comments

Comments
 (0)