 from MaxText.layers.embeddings import (
     LLaMARotaryEmbedding,
     LlamaVisionRotaryEmbedding,
+    Qwen3OmniMoeVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
 )
@@ -661,6 +662,14 @@ def convert_dense_general_inputs_shape(
     axis = canonicalize_tuple(axis)
     return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))
 
+  def get_vision_rotary_embedding_class(self):
+    """Gets the rotary embedding class based on the model type."""
+    if self.config.model_name.startswith("qwen3-omni"):
+      return Qwen3OmniMoeVisionRotaryEmbedding
+    elif self.config.model_name.startswith("llama4"):
+      return LlamaVisionRotaryEmbedding
+    raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
+
   def init_rotary_embedding(self):
     """Initializes the rotary embeddings, handling different model types.
 
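In effect, the helper resolves the vision RoPE class from the model-name prefix. A quick sketch with hypothetical model names (attn_qwen and attn_llama4 are illustrative instances, not from this PR):

# Illustrative only: how get_vision_rotary_embedding_class() resolves by prefix.
assert attn_qwen.config.model_name.startswith("qwen3-omni")
assert attn_qwen.get_vision_rotary_embedding_class() is Qwen3OmniMoeVisionRotaryEmbedding

assert attn_llama4.config.model_name.startswith("llama4")
assert attn_llama4.get_vision_rotary_embedding_class() is LlamaVisionRotaryEmbedding
# Any other model_name raises ValueError.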
@@ -676,15 +685,16 @@ def init_rotary_embedding(self):
     rope_type = self.config.rope_type.lower()
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
+      rotary_embedding_class = self.get_vision_rotary_embedding_class()
+      rotary_embedding = rotary_embedding_class(
           hidden_size=self.config.hidden_size_for_vit,
           num_attention_heads=self.config.num_attention_heads_for_vit,
+          spatial_merge_size=self.config.spatial_merge_size_for_vit,
           rope_theta=self.config.rope_theta_for_vit,
           fprop_dtype=self.dtype,
           rngs=self.rngs,
       )
+
     elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,
@@ -730,18 +740,27 @@ def init_rotary_embedding(self):
       )
     return rotary_embedding
 
-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: Optional[dict] = None
+  ):
     """Applies rotary embeddings, handling different model types.
 
     Args:
       inputs: The input tensor to apply rotary embeddings to.
       inputs_positions: The positions of the inputs.
-      name: A name for the embedding layer.
+      rope_kwargs: A dictionary of keyword arguments for the rotary embedding.
 
     Returns:
       The input tensor with rotary embeddings applied.
     """
-    return self.rotary_embedding(inputs, inputs_positions)
+    if self.is_vision and self.config.model_name.startswith("qwen3-omni"):
+      # For Qwen3OmniMoe vision, pass the static grid dimensions from rope_kwargs.
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      return self.rotary_embedding(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)
 
   def init_kv_caches(self, inputs_kv_shape: Tuple):
     """Initializes KVCache.
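For reference, a hedged sketch of the two call patterns through the new signature; only the keyword names come from this diff, while attn, query, positions, and the grid sizes are illustrative placeholders:

# Qwen3-Omni vision path: token positions are ignored, the static patch grid drives the RoPE.
rotated = attn.apply_rotary_embedding(query, rope_kwargs={"num_frames": 1, "height": 24, "width": 24})
# All other model families keep the existing position-based path.
rotated = attn.apply_rotary_embedding(query, inputs_positions=positions)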
@@ -823,6 +842,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      rope_kwargs: Optional[dict] = None,
   ):
     """Applies Attention on the input data.
 
@@ -886,8 +906,8 @@ def __call__(
     use_qk_norm = self.use_qk_norm and use_rope
 
     if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
 
     if use_qk_norm and is_llama4_decoder_block:
       l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
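Upstream, a Qwen3-Omni vision block would thread the patch-grid sizes through the new keyword. A sketch only; the argument names outside rope_kwargs and the grid variables are assumptions, not part of these hunks:

# Sketch (not the PR's code): a vision block forwarding its static patch-grid
# dimensions so apply_rotary_embedding can unpack them on the qwen3-omni path.
attn_out = self.attention_layer(
    hidden_states,
    hidden_states,
    inputs_positions=None,  # unused on the qwen3-omni vision path
    rope_kwargs={"num_frames": num_frames, "height": grid_h, "width": grid_w},
)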