Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/megatron/bridge/models/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,18 @@ def get_causal_lm_class_name_via_auto_map(
def conform_config_to_reference(
hf_config_dict: dict[str, object], reference_config: dict[str, object]
) -> dict[str, object]:
"""Return a projected hf_config_dict onto the reference key set, imputing missing keys with reference values."""
"""Return hf_config_dict projected onto reference keys, filling missing values from reference_config."""
reference_config_keys = set(reference_config.keys())
filtered_config_dict = {key: value for (key, value) in hf_config_dict.items() if key in reference_config_keys}
filtered_config_dict = {}
for key, value in hf_config_dict.items():
if key not in reference_config_keys:
continue

reference_value = reference_config[key]
if isinstance(value, dict) and isinstance(reference_value, dict):
value = conform_config_to_reference(value, reference_value)
filtered_config_dict[key] = value

for key, value in reference_config.items():
if key not in filtered_config_dict:
filtered_config_dict[key] = value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,12 @@ def __init__(
self.text_config = text_config
self.audio_token_id = audio_token_id

def to_cfg_dict(self) -> dict[str, object]:
"""Return an instantiable config dictionary for Megatron Bridge run_config.yaml."""
config = self.to_dict()
config["_target_"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
return config


class Qwen3ASRConfig(PretrainedConfig):
"""
Expand Down
74 changes: 74 additions & 0 deletions tests/unit_tests/models/qwen3_asr/test_qwen3_asr_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

import pytest

from megatron.bridge.models.conversion.utils import conform_config_to_reference
from megatron.bridge.models.qwen3_asr.hf_qwen3_asr.configuration_qwen3_asr import (
Qwen3ASRConfig,
Qwen3ASRThinkerConfig,
)
from megatron.bridge.training.config import ConfigContainer


pytestmark = [pytest.mark.unit]
Expand Down Expand Up @@ -51,3 +53,75 @@ def test_qwen3_asr_config_from_dict_constructs_thinker_config():
assert isinstance(config.thinker_config, Qwen3ASRThinkerConfig)
assert config.thinker_config.audio_config.encoder_layers == 2
assert config.thinker_config.text_config.hidden_size == 128


def test_qwen3_asr_config_conforming_preserves_reference_audio_subconfig():
reference_config = Qwen3ASRConfig.from_dict(
{
"model_type": "qwen3_asr",
"architectures": ["Qwen3ASRForConditionalGeneration"],
"thinker_config": {
"audio_config": {
"d_model": 1024,
"encoder_layers": 32,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
},
"text_config": {
"hidden_size": 2048,
"intermediate_size": 8192,
"num_hidden_layers": 12,
"num_attention_heads": 16,
"num_key_value_heads": 4,
"vocab_size": 151936,
},
},
}
)
megatron_derived_config = {
"model_type": "qwen3_asr",
"architectures": ["Qwen3ASRForConditionalGeneration"],
"thinker_config": {
"text_config": {
"hidden_size": 3584,
"intermediate_size": 18944,
"num_hidden_layers": 28,
"num_attention_heads": 28,
"num_key_value_heads": 4,
"vocab_size": 151936,
},
},
}

conformed_config = conform_config_to_reference(megatron_derived_config, reference_config.to_dict())
config = Qwen3ASRConfig(**conformed_config)

assert config.thinker_config.audio_config.d_model == 1024
assert config.thinker_config.audio_config.encoder_attention_heads == 16
assert config.thinker_config.text_config.hidden_size == 3584


def test_qwen3_asr_thinker_config_serializes_nested_subconfigs_for_run_config():
config = Qwen3ASRThinkerConfig(
audio_config={
"d_model": 1024,
"encoder_layers": 24,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"output_dim": 2048,
},
text_config={
"hidden_size": 2048,
"intermediate_size": 6144,
"num_hidden_layers": 28,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"vocab_size": 151936,
},
)

serialized = ConfigContainer._convert_value_to_dict(config)

assert serialized["audio_config"]["d_model"] == 1024
assert serialized["text_config"]["hidden_size"] == 2048
assert serialized["_target_"].endswith("Qwen3ASRThinkerConfig")
Loading