Commit 9f26995

Add Qwen3 Omni Vision Encoder
1 parent 341061f

8 files changed: +1934 −15 lines


src/MaxText/configs/base.yml

Lines changed: 7 additions & 0 deletions
@@ -884,6 +884,13 @@ vision_output_dim_for_vit: 4096
 pixel_shuffle_ratio_for_vit: 0.5
 projector_dropout_for_vit: 0.0

+# Qwen3-OmniMoe vision encoder
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 512
+temporal_patch_size_for_vit: 2
+num_position_embeddings_for_vit: 1024
+deepstack_visual_indexes_for_vit: []
+
 # Subslice shape in the form of "x,y,z" when using pathways (single controller).
 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
 subslice_shape: ""
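
These base.yml entries are only repo-wide defaults so the new keys exist on every config; the Qwen3-Omni model file below overrides the ones that differ for that checkpoint (out_hidden_size_for_vit, num_position_embeddings_for_vit, and deepstack_visual_indexes_for_vit in particular).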

src/MaxText/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 19 additions & 1 deletion
@@ -34,7 +34,25 @@ base_moe_mlp_dim: 768
 norm_topk_prob: true

 # RoPE Settings
-rope_max_timescale: 10_000_000
+rope_max_timescale: 1_000_000
+max_position_embeddings: 65536

 # General Model Settings
 enable_dropout: False
+
+# Vision Encoder Configuration
+# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
+image_size_for_vit: 768
+hidden_size_for_vit: 1152
+intermediate_size_for_vit: 4304
+num_attention_heads_for_vit: 16
+num_hidden_layers_for_vit: 27
+num_channels_for_vit: 3
+patch_size_for_vit: 16
+temporal_patch_size_for_vit: 2
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 2048
+num_position_embeddings_for_vit: 2304
+deepstack_visual_indexes_for_vit: [7, 16, 24]
+
+use_multimodal: true
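
As a quick sanity check on these values (assuming the standard ViT patching scheme, which this commit does not spell out), the position-embedding count matches the patch grid implied by image_size_for_vit and patch_size_for_vit, and spatial merging then shrinks that grid before projection to out_hidden_size_for_vit:

# Illustrative-only consistency check of the vision config above.
image_size = 768
patch_size = 16
spatial_merge_size = 2

patches_per_side = image_size // patch_size            # 768 // 16 = 48
num_patches = patches_per_side ** 2                    # 48 * 48 = 2304
assert num_patches == 2304                             # matches num_position_embeddings_for_vit

merged_per_side = patches_per_side // spatial_merge_size
vision_tokens_per_image = merged_per_side ** 2         # 24 * 24 = 576 tokens after 2x2 merging
print(num_patches, vision_tokens_per_image)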

src/MaxText/layers/attentions.py

Lines changed: 28 additions & 8 deletions
@@ -63,6 +63,7 @@
 from MaxText.layers.embeddings import (
     LLaMARotaryEmbedding,
     LlamaVisionRotaryEmbedding,
+    Qwen3OmniMoeVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
     Qwen3NextRotaryEmbedding,

@@ -705,6 +706,14 @@ def convert_dense_general_inputs_shape(
     axis = canonicalize_tuple(axis)
     return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))

+  def get_vision_rotary_embedding_class(self):
+    """Gets the rotary embedding class based on the model type."""
+    if self.config.model_name.startswith("qwen3-omni"):
+      return Qwen3OmniMoeVisionRotaryEmbedding
+    elif self.config.model_name.startswith("llama4"):
+      return LlamaVisionRotaryEmbedding
+    raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
+
   def init_rotary_embedding(self):
     """Initializes the rotary embeddings, handling different model types.

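Design note: the lookup keys off the model_name prefix, so the Llama4 and Qwen3-Omni vision towers share the same Attention stack and only swap the rotary-embedding module; any other model that reaches this path with is_vision set fails fast with a ValueError instead of silently applying the wrong RoPE.
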
@@ -720,15 +729,16 @@ def init_rotary_embedding(self):
     rope_type = self.config.rope_type.lower()
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
+      rotary_embbeding_class = self.get_vision_rotary_embedding_class()
+      rotary_embedding = rotary_embbeding_class(
           hidden_size=self.config.hidden_size_for_vit,
           num_attention_heads=self.config.num_attention_heads_for_vit,
+          spatial_merge_size=self.config.spatial_merge_size_for_vit,
           rope_theta=self.config.rope_theta_for_vit,
           fprop_dtype=self.dtype,
           rngs=self.rngs,
       )
+
     elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,

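Note that the vision rotary embedding is no longer constructed from image_size/patch_size; it now receives spatial_merge_size at construction time, and (for Qwen3-Omni) the actual frame/height/width of each input arrives later through rope_kwargs in apply_rotary_embedding below.
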
@@ -784,18 +794,27 @@ def init_rotary_embedding(self):
       )
     return rotary_embedding

-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict = None
+  ):
     """Applies rotary embeddings, handling different model types.

     Args:
       inputs: The input tensor to apply rotary embeddings to.
       inputs_positions: The positions of the inputs.
-      name: A name for the embedding layer.
+      rope_kwargs: A dictionary of keyword arguments for the rotary embedding.

     Returns:
       The input tensor with rotary embeddings applied.
     """
-    return self.rotary_embedding(inputs, inputs_positions)
+    if self.is_vision and self.config.model_name.startswith("qwen3-omni"):
+      # For Qwen3OmniMoe vision, pass static dimensions from kwargs
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      return self.rotary_embedding(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)

   def init_kv_caches(self, inputs_kv_shape: Tuple):
     """Initializes KVCache.

@@ -878,6 +897,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      rope_kwargs: dict = None,
   ):
     """Applies Attention on the input data.

@@ -952,8 +972,8 @@
     use_qk_norm = self.use_qk_norm and use_rope

     if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

     if use_qk_norm and is_llama4_decoder_block:
       l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
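
Because rope_kwargs defaults to None through both __call__ and apply_rotary_embedding, existing text-only call sites keep their behavior unchanged; only the Qwen3-Omni vision path reads the dictionary.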
