|
| 1 | +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Shared config-only AutoBridge/provider contracts for high-risk model families.""" |
| 16 | + |
| 17 | +from collections.abc import Callable, Mapping |
| 18 | +from dataclasses import dataclass |
| 19 | +from importlib import import_module |
| 20 | +from typing import Any |
| 21 | + |
| 22 | +import pytest |
| 23 | +from transformers import PretrainedConfig |
| 24 | + |
| 25 | +from megatron.bridge.models.conversion.auto_bridge import AutoBridge |
| 26 | +from megatron.bridge.models.qwen3_asr.hf_qwen3_asr.configuration_qwen3_asr import Qwen3ASRConfig |
| 27 | +from megatron.bridge.models.stepfun.configuration_step35 import Step35Config |
| 28 | + |
| 29 | + |
| 30 | +pytestmark = [pytest.mark.unit] |
| 31 | + |
| 32 | + |
| 33 | +@dataclass(frozen=True) |
| 34 | +class ModelProviderContractCase: |
| 35 | + """Config-only bridge/provider contract for one HF architecture.""" |
| 36 | + |
| 37 | + name: str |
| 38 | + architecture: str |
| 39 | + config_factory: Callable[[], PretrainedConfig] |
| 40 | + bridge_symbol: str |
| 41 | + provider_symbol: str |
| 42 | + expected_provider_attrs: Mapping[str, Any] |
| 43 | + |
| 44 | + |
| 45 | +def _resolve_symbol(qualified_name: str) -> type: |
| 46 | + module_name, symbol_name = qualified_name.rsplit(".", 1) |
| 47 | + return getattr(import_module(module_name), symbol_name) |
| 48 | + |
| 49 | + |
| 50 | +def _make_qwen3_asr_config() -> Qwen3ASRConfig: |
| 51 | + return Qwen3ASRConfig( |
| 52 | + architectures=["Qwen3ASRForConditionalGeneration"], |
| 53 | + thinker_config={ |
| 54 | + "torch_dtype": "bfloat16", |
| 55 | + "audio_config": { |
| 56 | + "encoder_layers": 2, |
| 57 | + }, |
| 58 | + "text_config": { |
| 59 | + "hidden_size": 128, |
| 60 | + "intermediate_size": 256, |
| 61 | + "num_hidden_layers": 2, |
| 62 | + "num_attention_heads": 4, |
| 63 | + "num_key_value_heads": 2, |
| 64 | + "vocab_size": 512, |
| 65 | + "max_position_embeddings": 1024, |
| 66 | + "initializer_range": 0.02, |
| 67 | + "rms_norm_eps": 1e-6, |
| 68 | + "rope_theta": 5000000.0, |
| 69 | + "tie_word_embeddings": False, |
| 70 | + }, |
| 71 | + }, |
| 72 | + ) |
| 73 | + |
| 74 | + |
| 75 | +def _make_step35_config() -> Step35Config: |
| 76 | + return Step35Config( |
| 77 | + hidden_size=128, |
| 78 | + intermediate_size=256, |
| 79 | + num_attention_heads=4, |
| 80 | + num_attention_groups=2, |
| 81 | + num_hidden_layers=4, |
| 82 | + vocab_size=512, |
| 83 | + max_position_embeddings=1024, |
| 84 | + moe_intermediate_size=64, |
| 85 | + moe_num_experts=4, |
| 86 | + moe_top_k=2, |
| 87 | + share_expert_dim=64, |
| 88 | + head_dim=32, |
| 89 | + layer_types=[ |
| 90 | + "full_attention", |
| 91 | + "sliding_attention", |
| 92 | + "full_attention", |
| 93 | + "sliding_attention", |
| 94 | + "full_attention", |
| 95 | + "sliding_attention", |
| 96 | + ], |
| 97 | + attention_other_setting={ |
| 98 | + "attention_type": "sliding_attention", |
| 99 | + "num_attention_heads": 4, |
| 100 | + "num_attention_groups": 2, |
| 101 | + "head_dim": 32, |
| 102 | + }, |
| 103 | + sliding_window=128, |
| 104 | + num_nextn_predict_layers=2, |
| 105 | + moe_layers_enum=(2, 3), |
| 106 | + torch_dtype="bfloat16", |
| 107 | + ) |
| 108 | + |
| 109 | + |
| 110 | +def _make_mimo_v2_flash_config() -> PretrainedConfig: |
| 111 | + return PretrainedConfig( |
| 112 | + architectures=["MiMoV2FlashForCausalLM"], |
| 113 | + model_type="mimo_v2_flash", |
| 114 | + num_hidden_layers=6, |
| 115 | + hidden_size=256, |
| 116 | + intermediate_size=512, |
| 117 | + num_attention_heads=8, |
| 118 | + num_key_value_heads=2, |
| 119 | + head_dim=32, |
| 120 | + vocab_size=1024, |
| 121 | + max_position_embeddings=2048, |
| 122 | + rope_theta=5000000, |
| 123 | + rms_norm_eps=1e-5, |
| 124 | + initializer_range=0.02, |
| 125 | + tie_word_embeddings=False, |
| 126 | + attention_bias=False, |
| 127 | + mlp_bias=False, |
| 128 | + hidden_act="silu", |
| 129 | + layernorm_epsilon=1e-5, |
| 130 | + v_head_dim=16, |
| 131 | + hybrid_layer_pattern=[0, 1, 1, 1, 0, 1], |
| 132 | + sliding_window_size=128, |
| 133 | + sliding_window=128, |
| 134 | + attention_chunk_size=128, |
| 135 | + swa_rope_theta=10000, |
| 136 | + swa_num_key_value_heads=4, |
| 137 | + swa_num_attention_heads=8, |
| 138 | + swa_head_dim=32, |
| 139 | + swa_v_head_dim=16, |
| 140 | + add_swa_attention_sink_bias=True, |
| 141 | + add_full_attention_sink_bias=False, |
| 142 | + attention_value_scale=0.707, |
| 143 | + moe_layer_freq=[0, 1, 1, 1, 1, 1], |
| 144 | + n_routed_experts=8, |
| 145 | + moe_intermediate_size=128, |
| 146 | + num_experts_per_tok=2, |
| 147 | + scoring_func="sigmoid", |
| 148 | + n_shared_experts=None, |
| 149 | + n_group=1, |
| 150 | + topk_group=1, |
| 151 | + topk_method="noaux_tc", |
| 152 | + norm_topk_prob=True, |
| 153 | + routed_scaling_factor=None, |
| 154 | + torch_dtype="bfloat16", |
| 155 | + ) |
| 156 | + |
| 157 | + |
| 158 | +def _make_nemotron_labs_diffusion_config() -> PretrainedConfig: |
| 159 | + text_config = PretrainedConfig( |
| 160 | + hidden_size=128, |
| 161 | + intermediate_size=256, |
| 162 | + num_hidden_layers=2, |
| 163 | + tie_word_embeddings=True, |
| 164 | + rope_parameters={"rope_theta": 10000.0}, |
| 165 | + vocab_size=512, |
| 166 | + ) |
| 167 | + return PretrainedConfig( |
| 168 | + architectures=["NemotronLabsDiffusionModel"], |
| 169 | + model_type="nemotron_labs_diffusion", |
| 170 | + text_config=text_config, |
| 171 | + ) |
| 172 | + |
| 173 | + |
| 174 | +G_CONTRACT_CASES = ( |
| 175 | + ModelProviderContractCase( |
| 176 | + name="qwen3_asr_nested_config", |
| 177 | + architecture="Qwen3ASRForConditionalGeneration", |
| 178 | + config_factory=_make_qwen3_asr_config, |
| 179 | + bridge_symbol="megatron.bridge.models.qwen3_asr.qwen3_asr_bridge.Qwen3ASRBridge", |
| 180 | + provider_symbol="megatron.bridge.models.qwen3_asr.qwen3_asr_provider.Qwen3ASRModelProvider", |
| 181 | + expected_provider_attrs={ |
| 182 | + "hidden_size": 128, |
| 183 | + "num_layers": 2, |
| 184 | + "num_query_groups": 2, |
| 185 | + "vocab_size": 512, |
| 186 | + "audio_token_id": 151646, |
| 187 | + "share_embeddings_and_output_weights": False, |
| 188 | + }, |
| 189 | + ), |
| 190 | + ModelProviderContractCase( |
| 191 | + name="step35_mtp_layer_types", |
| 192 | + architecture="Step3p5ForCausalLM", |
| 193 | + config_factory=_make_step35_config, |
| 194 | + bridge_symbol="megatron.bridge.models.stepfun.step35_bridge.Step35Bridge", |
| 195 | + provider_symbol="megatron.bridge.models.stepfun.step35_provider.Step35ModelProvider", |
| 196 | + expected_provider_attrs={ |
| 197 | + "hidden_size": 128, |
| 198 | + "num_layers": 4, |
| 199 | + "num_query_groups": 2, |
| 200 | + "num_moe_experts": 4, |
| 201 | + "moe_router_topk": 2, |
| 202 | + "moe_layer_freq": [0, 0, 1, 1], |
| 203 | + "layer_types": [ |
| 204 | + "full_attention", |
| 205 | + "sliding_attention", |
| 206 | + "full_attention", |
| 207 | + "sliding_attention", |
| 208 | + "full_attention", |
| 209 | + "sliding_attention", |
| 210 | + ], |
| 211 | + }, |
| 212 | + ), |
| 213 | + ModelProviderContractCase( |
| 214 | + name="mimo_v2_flash_registration", |
| 215 | + architecture="MiMoV2FlashForCausalLM", |
| 216 | + config_factory=_make_mimo_v2_flash_config, |
| 217 | + bridge_symbol="megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_bridge.MiMoV2FlashBridge", |
| 218 | + provider_symbol="megatron.bridge.models.mimo_v2_flash.mimo_v2_flash_provider.MiMoV2FlashModelProvider", |
| 219 | + expected_provider_attrs={ |
| 220 | + "hidden_size": 256, |
| 221 | + "num_layers": 6, |
| 222 | + "num_query_groups": 2, |
| 223 | + "full_attn_num_query_groups": 2, |
| 224 | + "swa_num_query_groups": 4, |
| 225 | + "v_head_dim": 16, |
| 226 | + "window_size": 128, |
| 227 | + "mtp_num_layers": 0, |
| 228 | + }, |
| 229 | + ), |
| 230 | + ModelProviderContractCase( |
| 231 | + name="nemotron_labs_diffusion_text_config", |
| 232 | + architecture="NemotronLabsDiffusionModel", |
| 233 | + config_factory=_make_nemotron_labs_diffusion_config, |
| 234 | + bridge_symbol=( |
| 235 | + "megatron.bridge.diffusion.conversion.nemotron_labs_diffusion." |
| 236 | + "nemotron_labs_diffusion_bridge.NemotronLabsDiffusionBridge" |
| 237 | + ), |
| 238 | + provider_symbol=( |
| 239 | + "megatron.bridge.diffusion.models.nemotron_labs_diffusion." |
| 240 | + "nemotron_labs_diffusion_provider.NemotronLabsDiffusionModelProvider" |
| 241 | + ), |
| 242 | + expected_provider_attrs={ |
| 243 | + "hidden_size": 128, |
| 244 | + "ffn_hidden_size": 256, |
| 245 | + "num_layers": 2, |
| 246 | + "vocab_size": 512, |
| 247 | + "share_embeddings_and_output_weights": True, |
| 248 | + "rotary_base": 10000.0, |
| 249 | + }, |
| 250 | + ), |
| 251 | +) |
| 252 | + |
| 253 | + |
| 254 | +@pytest.mark.parametrize("case", G_CONTRACT_CASES, ids=[case.name for case in G_CONTRACT_CASES]) |
| 255 | +def test_config_only_autobridge_provider_contract(case: ModelProviderContractCase) -> None: |
| 256 | + bridge_type = _resolve_symbol(case.bridge_symbol) |
| 257 | + provider_type = _resolve_symbol(case.provider_symbol) |
| 258 | + config = case.config_factory() |
| 259 | + |
| 260 | + assert AutoBridge.supports(config) is True |
| 261 | + assert case.architecture in AutoBridge.list_supported_models() |
| 262 | + |
| 263 | + bridge = AutoBridge.from_hf_config(config) |
| 264 | + assert isinstance(bridge._model_bridge, bridge_type) |
| 265 | + |
| 266 | + provider = bridge.to_megatron_provider(load_weights=False) |
| 267 | + assert isinstance(provider, provider_type) |
| 268 | + for attr_name, expected_value in case.expected_provider_attrs.items(): |
| 269 | + assert getattr(provider, attr_name) == expected_value |
0 commit comments