@@ -65,12 +65,23 @@ def _build_text_only_mrope_position_ids(input_ids: torch.Tensor) -> torch.Tensor
6565 return torch .stack ([base , base , base ], dim = 0 )
6666
6767
68+ _AUDIO_ENCODER_CHUNK_SIZE = 100
69+ _AUDIO_ENCODER_TOKENS_PER_FULL_CHUNK = 13
70+
71+
6872def _get_qwen3_omni_audio_output_lengths (input_lengths : torch .LongTensor ) -> torch .LongTensor :
6973 """Match HF Qwen3-Omni audio encoder forward output lengths."""
7074
71- input_lengths_leave = input_lengths % 100
75+ # HF Qwen3-Omni audio encoder handles full 100-frame chunks separately.
76+ # Each full chunk contributes 13 output tokens; the remainder goes through
77+ # two stride-2 convolutions followed by a final stride-2 pooling layer.
78+ input_lengths_leave = input_lengths % _AUDIO_ENCODER_CHUNK_SIZE
7279 feat_lengths = (input_lengths_leave - 1 ) // 2 + 1
73- return ((feat_lengths - 1 ) // 2 + 1 - 1 ) // 2 + 1 + (input_lengths // 100 ) * 13
80+ return (
81+ ((feat_lengths - 1 ) // 2 + 1 - 1 ) // 2
82+ + 1
83+ + (input_lengths // _AUDIO_ENCODER_CHUNK_SIZE ) * _AUDIO_ENCODER_TOKENS_PER_FULL_CHUNK
84+ )
7485
7586
7687def _configure_multimodal_attn_impl (config : object , attn_impl : str | None ) -> None :
@@ -616,26 +627,19 @@ def forward(
616627 position_ids = torch .nn .functional .pad (position_ids , (0 , sp_pad_len ), mode = "replicate" )
617628
618629 if self .config .sequence_parallel or cp_size > 1 :
619- if packed_seq_params is None :
620- visual_pos_masks , deepstack_visual_embeds = split_deepstack_embs (
621- visual_pos_masks ,
622- deepstack_visual_embeds ,
623- tp_size = tp_size ,
624- tp_rank = tp_rank ,
625- cp_size = cp_size ,
626- cp_rank = cp_rank ,
627- sequence_parallel = self .config .sequence_parallel ,
628- )
629- elif self .config .sequence_parallel :
630- visual_pos_masks , deepstack_visual_embeds = split_deepstack_embs (
631- visual_pos_masks ,
632- deepstack_visual_embeds ,
633- tp_size = tp_size ,
634- tp_rank = tp_rank ,
635- cp_size = 1 ,
636- cp_rank = 0 ,
637- sequence_parallel = self .config .sequence_parallel ,
638- )
630+ # Packed THD tensors are already CP-aware after preprocess_packed_seqs;
631+ # only the SP split remains for deepstack embeddings.
632+ split_cp_size = 1 if packed_seq_params is not None else cp_size
633+ split_cp_rank = 0 if packed_seq_params is not None else cp_rank
634+ visual_pos_masks , deepstack_visual_embeds = split_deepstack_embs (
635+ visual_pos_masks ,
636+ deepstack_visual_embeds ,
637+ tp_size = tp_size ,
638+ tp_rank = tp_rank ,
639+ cp_size = split_cp_size ,
640+ cp_rank = split_cp_rank ,
641+ sequence_parallel = self .config .sequence_parallel ,
642+ )
639643
640644 if packed_seq_params is not None and position_ids is not None :
641645 position_ids = (
@@ -649,16 +653,20 @@ def forward(
649653 .contiguous ()
650654 )
651655 attention_mask = None
652- self .language_model .rotary_pos_emb .is_thd_format = True
656+
657+ rotary_pos_emb = getattr (self .language_model , "rotary_pos_emb" , None )
658+ if rotary_pos_emb is not None :
659+ rotary_pos_emb .is_thd_format = packed_seq_params is not None
660+
661+ if packed_seq_params is not None :
662+ language_model_input_ids = lm_input_ids
663+ elif combined_embeddings is not None :
664+ language_model_input_ids = None
665+ else :
666+ language_model_input_ids = input_ids
653667
654668 return self .language_model (
655- input_ids = (
656- lm_input_ids
657- if packed_seq_params is not None
658- else None
659- if combined_embeddings is not None
660- else input_ids
661- ),
669+ input_ids = language_model_input_ids ,
662670 position_ids = position_ids ,
663671 attention_mask = attention_mask ,
664672 decoder_input = combined_embeddings ,
0 commit comments