 from MaxText.layers.embeddings import (
     LLaMARotaryEmbedding,
     LlamaVisionRotaryEmbedding,
+    Qwen3OmniMoeVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
     Qwen3NextRotaryEmbedding,
@@ -705,6 +706,14 @@ def convert_dense_general_inputs_shape(
     axis = canonicalize_tuple(axis)
     return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))

+  def get_vision_rotary_embedding_class(self):
+    """Gets the rotary embedding class based on the model type."""
+    if self.config.model_name.startswith("qwen3-omni"):
+      return Qwen3OmniMoeVisionRotaryEmbedding
+    elif self.config.model_name.startswith("llama4"):
+      return LlamaVisionRotaryEmbedding
+    raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
+
   def init_rotary_embedding(self):
     """Initializes the rotary embeddings, handling different model types.

@@ -720,15 +729,16 @@ def init_rotary_embedding(self):
     rope_type = self.config.rope_type.lower()
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
+      rotary_embedding_class = self.get_vision_rotary_embedding_class()
+      rotary_embedding = rotary_embedding_class(
           hidden_size=self.config.hidden_size_for_vit,
           num_attention_heads=self.config.num_attention_heads_for_vit,
+          spatial_merge_size=self.config.spatial_merge_size_for_vit,
           rope_theta=self.config.rope_theta_for_vit,
           fprop_dtype=self.dtype,
           rngs=self.rngs,
       )
+
     elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,
@@ -784,18 +794,27 @@ def init_rotary_embedding(self):
       )
     return rotary_embedding

-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict = None
+  ):
     """Applies rotary embeddings, handling different model types.

     Args:
       inputs: The input tensor to apply rotary embeddings to.
       inputs_positions: The positions of the inputs.
-      name: A name for the embedding layer.
+      rope_kwargs: A dictionary of keyword arguments for the rotary embedding.

     Returns:
       The input tensor with rotary embeddings applied.
     """
-    return self.rotary_embedding(inputs, inputs_positions)
+    if self.is_vision and self.config.model_name.startswith("qwen3-omni"):
+      # For Qwen3OmniMoe vision, pass static dimensions from kwargs
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      return self.rotary_embedding(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)

   def init_kv_caches(self, inputs_kv_shape: Tuple):
     """Initializes KVCache.
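
For readers following the rope_kwargs change, here is a minimal, self-contained sketch of the new dispatch, written outside of MaxText. Only the key names ("num_frames", "height", "width") and the positional call order on the Qwen3-Omni vision path come from the diff; the helper names and the is_qwen3_omni_vision flag are illustrative.

def build_vision_rope_kwargs(num_frames: int, height: int, width: int) -> dict:
  """Collects the static grid dimensions that the vision rotary embedding expects."""
  return {"num_frames": num_frames, "height": height, "width": width}


def apply_rotary(rotary_embedding, inputs, inputs_positions=None, rope_kwargs=None, is_qwen3_omni_vision=False):
  """Mirrors the branching added to apply_rotary_embedding above."""
  if is_qwen3_omni_vision:
    kwargs = rope_kwargs or {}
    # The vision rotary embedding receives the static grid dimensions positionally.
    return rotary_embedding(inputs, kwargs.get("num_frames"), kwargs.get("height"), kwargs.get("width"))
  # All other models keep the existing (inputs, inputs_positions) call.
  return rotary_embedding(inputs, inputs_positions)
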
@@ -878,6 +897,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      rope_kwargs: dict = None,
   ):
     """Applies Attention on the input data.

@@ -952,8 +972,8 @@ def __call__(
     use_qk_norm = self.use_qk_norm and use_rope

     if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

     if use_qk_norm and is_llama4_decoder_block:
       l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
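
And a compact mock of where rope_kwargs enters the stack, per the last two hunks: the attention layer's __call__ accepts it and forwards it unchanged when rotating query and key. The ToyAttention class and the identity rotary stand-in below are invented for illustration; only the parameter threading matches the diff.

import jax.numpy as jnp


class ToyAttention:
  """Illustrative stand-in for the attention layer's rope_kwargs plumbing; not MaxText code."""

  def __init__(self, apply_rope_fn):
    # apply_rope_fn(inputs, inputs_positions=None, rope_kwargs=None) -> rotated inputs
    self.apply_rope_fn = apply_rope_fn

  def __call__(self, query, key, inputs_positions=None, rope_kwargs=None):
    # As in the diff: rope_kwargs is accepted by __call__ and passed to both rotations.
    query = self.apply_rope_fn(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
    key = self.apply_rope_fn(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
    return query, key


identity_rope = lambda x, inputs_positions=None, rope_kwargs=None: x  # no-op rotary stand-in
attn = ToyAttention(identity_rope)
q = k = jnp.zeros((1, 256, 4, 64))  # [batch, seq_len, num_heads, head_dim]
q_rot, k_rot = attn(q, k, rope_kwargs={"num_frames": 1, "height": 16, "width": 16})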