From 7b6e550feba2e8115d190a625cb6e2d0f3a490d2 Mon Sep 17 00:00:00 2001 From: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Date: Mon, 15 Jun 2026 19:00:27 -0700 Subject: [PATCH] fix(qwen3-asr): preserve audio config for checkpoint export (#4361) Signed-off-by: yaoyu-33 Signed-off-by: NeMo Bot --- .../bridge/models/conversion/utils.py | 13 +++- .../hf_qwen3_asr/configuration_qwen3_asr.py | 6 ++ .../models/qwen3_asr/test_qwen3_asr_config.py | 74 +++++++++++++++++++ 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/src/megatron/bridge/models/conversion/utils.py b/src/megatron/bridge/models/conversion/utils.py index 5bb44fb0a1..3e3e1a4f63 100644 --- a/src/megatron/bridge/models/conversion/utils.py +++ b/src/megatron/bridge/models/conversion/utils.py @@ -309,9 +309,18 @@ def get_causal_lm_class_name_via_auto_map( def conform_config_to_reference( hf_config_dict: dict[str, object], reference_config: dict[str, object] ) -> dict[str, object]: - """Return a projected hf_config_dict onto the reference key set, imputing missing keys with reference values.""" + """Return hf_config_dict projected onto reference keys, filling missing values from reference_config.""" reference_config_keys = set(reference_config.keys()) - filtered_config_dict = {key: value for (key, value) in hf_config_dict.items() if key in reference_config_keys} + filtered_config_dict = {} + for key, value in hf_config_dict.items(): + if key not in reference_config_keys: + continue + + reference_value = reference_config[key] + if isinstance(value, dict) and isinstance(reference_value, dict): + value = conform_config_to_reference(value, reference_value) + filtered_config_dict[key] = value + for key, value in reference_config.items(): if key not in filtered_config_dict: filtered_config_dict[key] = value diff --git a/src/megatron/bridge/models/qwen3_asr/hf_qwen3_asr/configuration_qwen3_asr.py b/src/megatron/bridge/models/qwen3_asr/hf_qwen3_asr/configuration_qwen3_asr.py index a764693ad5..dcd629a5e6 100644 --- a/src/megatron/bridge/models/qwen3_asr/hf_qwen3_asr/configuration_qwen3_asr.py +++ b/src/megatron/bridge/models/qwen3_asr/hf_qwen3_asr/configuration_qwen3_asr.py @@ -354,6 +354,12 @@ def __init__( self.text_config = text_config self.audio_token_id = audio_token_id + def to_cfg_dict(self) -> dict[str, object]: + """Return an instantiable config dictionary for Megatron Bridge run_config.yaml.""" + config = self.to_dict() + config["_target_"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + return config + class Qwen3ASRConfig(PretrainedConfig): """ diff --git a/tests/unit_tests/models/qwen3_asr/test_qwen3_asr_config.py b/tests/unit_tests/models/qwen3_asr/test_qwen3_asr_config.py index 3402d84411..baa6385590 100644 --- a/tests/unit_tests/models/qwen3_asr/test_qwen3_asr_config.py +++ b/tests/unit_tests/models/qwen3_asr/test_qwen3_asr_config.py @@ -14,10 +14,12 @@ import pytest +from megatron.bridge.models.conversion.utils import conform_config_to_reference from megatron.bridge.models.qwen3_asr.hf_qwen3_asr.configuration_qwen3_asr import ( Qwen3ASRConfig, Qwen3ASRThinkerConfig, ) +from megatron.bridge.training.config import ConfigContainer pytestmark = [pytest.mark.unit] @@ -51,3 +53,75 @@ def test_qwen3_asr_config_from_dict_constructs_thinker_config(): assert isinstance(config.thinker_config, Qwen3ASRThinkerConfig) assert config.thinker_config.audio_config.encoder_layers == 2 assert config.thinker_config.text_config.hidden_size == 128 + + +def test_qwen3_asr_config_conforming_preserves_reference_audio_subconfig(): + reference_config = Qwen3ASRConfig.from_dict( + { + "model_type": "qwen3_asr", + "architectures": ["Qwen3ASRForConditionalGeneration"], + "thinker_config": { + "audio_config": { + "d_model": 1024, + "encoder_layers": 32, + "encoder_attention_heads": 16, + "encoder_ffn_dim": 4096, + }, + "text_config": { + "hidden_size": 2048, + "intermediate_size": 8192, + "num_hidden_layers": 12, + "num_attention_heads": 16, + "num_key_value_heads": 4, + "vocab_size": 151936, + }, + }, + } + ) + megatron_derived_config = { + "model_type": "qwen3_asr", + "architectures": ["Qwen3ASRForConditionalGeneration"], + "thinker_config": { + "text_config": { + "hidden_size": 3584, + "intermediate_size": 18944, + "num_hidden_layers": 28, + "num_attention_heads": 28, + "num_key_value_heads": 4, + "vocab_size": 151936, + }, + }, + } + + conformed_config = conform_config_to_reference(megatron_derived_config, reference_config.to_dict()) + config = Qwen3ASRConfig(**conformed_config) + + assert config.thinker_config.audio_config.d_model == 1024 + assert config.thinker_config.audio_config.encoder_attention_heads == 16 + assert config.thinker_config.text_config.hidden_size == 3584 + + +def test_qwen3_asr_thinker_config_serializes_nested_subconfigs_for_run_config(): + config = Qwen3ASRThinkerConfig( + audio_config={ + "d_model": 1024, + "encoder_layers": 24, + "encoder_attention_heads": 16, + "encoder_ffn_dim": 4096, + "output_dim": 2048, + }, + text_config={ + "hidden_size": 2048, + "intermediate_size": 6144, + "num_hidden_layers": 28, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "vocab_size": 151936, + }, + ) + + serialized = ConfigContainer._convert_value_to_dict(config) + + assert serialized["audio_config"]["d_model"] == 1024 + assert serialized["text_config"]["hidden_size"] == 2048 + assert serialized["_target_"].endswith("Qwen3ASRThinkerConfig")