Skip to content

Commit d179d09

Browse files
committed
[model] Address Qwen3-Omni packseq review comments
Signed-off-by: hbhflw2000 <417911774@qq.com>
1 parent ce916c6 commit d179d09

3 files changed

Lines changed: 87 additions & 30 deletions

File tree

src/megatron/bridge/models/qwen_omni/modeling_qwen3_omni/thinker_model.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,23 @@ def _build_text_only_mrope_position_ids(input_ids: torch.Tensor) -> torch.Tensor
6565
return torch.stack([base, base, base], dim=0)
6666

6767

68+
_AUDIO_ENCODER_CHUNK_SIZE = 100
69+
_AUDIO_ENCODER_TOKENS_PER_FULL_CHUNK = 13
70+
71+
6872
def _get_qwen3_omni_audio_output_lengths(input_lengths: torch.LongTensor) -> torch.LongTensor:
6973
"""Match HF Qwen3-Omni audio encoder forward output lengths."""
7074

71-
input_lengths_leave = input_lengths % 100
75+
# HF Qwen3-Omni audio encoder handles full 100-frame chunks separately.
76+
# Each full chunk contributes 13 output tokens; the remainder goes through
77+
# two stride-2 convolutions followed by a final stride-2 pooling layer.
78+
input_lengths_leave = input_lengths % _AUDIO_ENCODER_CHUNK_SIZE
7279
feat_lengths = (input_lengths_leave - 1) // 2 + 1
73-
return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
80+
return (
81+
((feat_lengths - 1) // 2 + 1 - 1) // 2
82+
+ 1
83+
+ (input_lengths // _AUDIO_ENCODER_CHUNK_SIZE) * _AUDIO_ENCODER_TOKENS_PER_FULL_CHUNK
84+
)
7485

7586

7687
def _configure_multimodal_attn_impl(config: object, attn_impl: str | None) -> None:
@@ -616,26 +627,19 @@ def forward(
616627
position_ids = torch.nn.functional.pad(position_ids, (0, sp_pad_len), mode="replicate")
617628

618629
if self.config.sequence_parallel or cp_size > 1:
619-
if packed_seq_params is None:
620-
visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs(
621-
visual_pos_masks,
622-
deepstack_visual_embeds,
623-
tp_size=tp_size,
624-
tp_rank=tp_rank,
625-
cp_size=cp_size,
626-
cp_rank=cp_rank,
627-
sequence_parallel=self.config.sequence_parallel,
628-
)
629-
elif self.config.sequence_parallel:
630-
visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs(
631-
visual_pos_masks,
632-
deepstack_visual_embeds,
633-
tp_size=tp_size,
634-
tp_rank=tp_rank,
635-
cp_size=1,
636-
cp_rank=0,
637-
sequence_parallel=self.config.sequence_parallel,
638-
)
630+
# Packed THD tensors are already CP-aware after preprocess_packed_seqs;
631+
# only the SP split remains for deepstack embeddings.
632+
split_cp_size = 1 if packed_seq_params is not None else cp_size
633+
split_cp_rank = 0 if packed_seq_params is not None else cp_rank
634+
visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs(
635+
visual_pos_masks,
636+
deepstack_visual_embeds,
637+
tp_size=tp_size,
638+
tp_rank=tp_rank,
639+
cp_size=split_cp_size,
640+
cp_rank=split_cp_rank,
641+
sequence_parallel=self.config.sequence_parallel,
642+
)
639643

640644
if packed_seq_params is not None and position_ids is not None:
641645
position_ids = (
@@ -649,16 +653,20 @@ def forward(
649653
.contiguous()
650654
)
651655
attention_mask = None
652-
self.language_model.rotary_pos_emb.is_thd_format = True
656+
657+
rotary_pos_emb = getattr(self.language_model, "rotary_pos_emb", None)
658+
if rotary_pos_emb is not None:
659+
rotary_pos_emb.is_thd_format = packed_seq_params is not None
660+
661+
if packed_seq_params is not None:
662+
language_model_input_ids = lm_input_ids
663+
elif combined_embeddings is not None:
664+
language_model_input_ids = None
665+
else:
666+
language_model_input_ids = input_ids
653667

654668
return self.language_model(
655-
input_ids=(
656-
lm_input_ids
657-
if packed_seq_params is not None
658-
else None
659-
if combined_embeddings is not None
660-
else input_ids
661-
),
669+
input_ids=language_model_input_ids,
662670
position_ids=position_ids,
663671
attention_mask=attention_mask,
664672
decoder_input=combined_embeddings,

src/megatron/bridge/models/qwen_omni/qwen3_omni_step.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def forward_step(
284284
attention_mask,
285285
position_ids,
286286
pg_collection,
287+
# Keep packed THD lengths TE-friendly even when the recipe toggles FP8 later.
287288
use_fp8_padding=True,
288289
force_to_seq_length=_parallel_size(pg_collection, "pp") > 1 or _parallel_size(pg_collection, "ep") > 1,
289290
seq_length=getattr(config, "seq_length", getattr(state.cfg.model, "seq_length", None)),

tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,54 @@ def _fake_get_rope_index(*args, **kwargs): # noqa: ARG001
455455
assert rope_calls["attention_mask"] is None
456456
assert fake_language_model.forward_kwargs["attention_mask"] is attention_mask
457457

458+
def test_packed_forward_resets_rotary_thd_state(self):
459+
class _MockProcessGroup:
460+
def size(self):
461+
return 1
462+
463+
def rank(self):
464+
return 0
465+
466+
class _FakeLanguageModel(torch.nn.Module):
467+
def __init__(self):
468+
super().__init__()
469+
self.rotary_pos_emb = SimpleNamespace(is_thd_format=False)
470+
self.forward_kwargs = []
471+
472+
def forward(self, **kwargs):
473+
self.forward_kwargs.append(kwargs)
474+
return torch.tensor(0.0)
475+
476+
fake_language_model = _FakeLanguageModel()
477+
thinker = SimpleNamespace(
478+
pg_collection=SimpleNamespace(cp=_MockProcessGroup(), tp=_MockProcessGroup()),
479+
config=SimpleNamespace(sequence_parallel=False),
480+
pre_process=False,
481+
language_model=fake_language_model,
482+
)
483+
484+
input_ids = torch.tensor([[1, 2, 3, 4]])
485+
position_ids = torch.arange(input_ids.size(1)).view(1, 1, -1).expand(3, input_ids.size(0), -1)
486+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
487+
488+
Qwen3OmniThinkerModel.forward(
489+
thinker,
490+
input_ids=input_ids,
491+
position_ids=position_ids,
492+
attention_mask=attention_mask,
493+
packed_seq_params=object(),
494+
)
495+
assert fake_language_model.rotary_pos_emb.is_thd_format is True
496+
497+
Qwen3OmniThinkerModel.forward(
498+
thinker,
499+
input_ids=input_ids,
500+
position_ids=position_ids,
501+
attention_mask=None,
502+
packed_seq_params=None,
503+
)
504+
assert fake_language_model.rotary_pos_emb.is_thd_format is False
505+
458506
def test_audio_forward(self, thinker_config):
459507
model = self._build_model(thinker_config)
460508
if torch.cuda.is_available():

0 commit comments

Comments
 (0)