
Commit c79b379

Add qwen3 omni vision encoder
1 parent 42cb7ed commit c79b379

File tree

8 files changed: +1896 / -12 lines


src/MaxText/configs/base.yml

Lines changed: 7 additions & 0 deletions
@@ -853,6 +853,13 @@ vision_output_dim_for_vit: 4096
 pixel_shuffle_ratio_for_vit: 0.5
 projector_dropout_for_vit: 0.0

+# Qwen3-OmniMoe vision encoder
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 512
+temporal_patch_size_for_vit: 2
+num_position_embeddings_for_vit: 1024
+deepstack_visual_indexes_for_vit: []
+
 # Subslice shape in the form of "x,y,z" when using pathways (single controller).
 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
 subslice_shape: ""
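
For orientation, here is a minimal illustrative sketch, not code from this commit, of what spatial_merge_size_for_vit and out_hidden_size_for_vit conventionally control in Qwen-style vision encoders: each spatial_merge_size x spatial_merge_size block of neighbouring patch embeddings is concatenated and projected to out_hidden_size before the visual tokens reach the decoder. The spatial_merge helper, w_proj, and all shapes below are assumptions made for the example.

import jax.numpy as jnp

def spatial_merge(patches, grid_h, grid_w, merge, w_proj):
  """Illustrative 2x2 patch merge; not MaxText code.

  patches: [grid_h * grid_w, hidden] patch embeddings in row-major order.
  w_proj:  [merge * merge * hidden, out_hidden] projection to out_hidden_size_for_vit.
  """
  hidden = patches.shape[-1]
  # Split the flat patch sequence into merge x merge neighbourhoods.
  x = patches.reshape(grid_h // merge, merge, grid_w // merge, merge, hidden)
  # Group each neighbourhood's features together, then concatenate them.
  x = x.transpose(0, 2, 1, 3, 4).reshape(-1, merge * merge * hidden)
  return x @ w_proj  # [(grid_h * grid_w) / merge**2, out_hidden]

# With the base.yml defaults above (spatial_merge_size 2, out_hidden_size 512) and an
# assumed 32x32 patch grid (num_position_embeddings_for_vit = 1024 = 32 * 32); the 1152
# hidden size is taken from the Qwen3-Omni model config added in this commit.
patches = jnp.zeros((32 * 32, 1152))
w_proj = jnp.zeros((2 * 2 * 1152, 512))
print(spatial_merge(patches, 32, 32, 2, w_proj).shape)  # (256, 512)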
Lines changed: 40 additions & 0 deletions (new file)

@@ -0,0 +1,40 @@
+# Core Architectural Parameters
+decoder_block: "qwen3_moe"
+base_emb_dim: 2048
+base_mlp_dim: 768
+base_num_query_heads: 32
+base_num_kv_heads: 4
+base_num_decoder_layers: 48
+head_dim: 128
+mlp_activations: ["silu", "linear"]
+vocab_size: 152064
+normalization_layer_epsilon: 1.0e-6
+use_qk_norm: True
+
+# MoE Specific Parameters
+num_experts: 128
+num_experts_per_tok: 8
+base_moe_mlp_dim: 768
+norm_topk_prob: true
+
+# RoPE Settings
+rope_max_timescale: 1_000_000
+max_position_embeddings: 65536
+
+# General Model Settings
+enable_dropout: False
+
+# Vision Encoder Configuration
+# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
+image_size_for_vit: 768
+hidden_size_for_vit: 1152
+intermediate_size_for_vit: 4304
+num_attention_heads_for_vit: 16
+num_hidden_layers_for_vit: 27
+num_channels_for_vit: 3
+patch_size_for_vit: 16
+temporal_patch_size_for_vit: 2
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 2048
+num_position_embeddings_for_vit: 2304
+deepstack_visual_indexes_for_vit: [8, 16, 24]
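
A quick consistency check on the vision numbers above (plain arithmetic, assuming a square patch grid as in the referenced Hugging Face configuration, and that spatial merging reduces the token count by spatial_merge_size squared):

image_size, patch_size = 768, 16
spatial_merge_size = 2
num_hidden_layers = 27
deepstack_visual_indexes = [8, 16, 24]

grid_side = image_size // patch_size                       # 48 patches per side
num_positions = grid_side * grid_side                      # 2304 == num_position_embeddings_for_vit
tokens_per_image = (grid_side // spatial_merge_size) ** 2  # 576 visual tokens after the 2x2 merge
assert num_positions == 2304
assert all(i < num_hidden_layers for i in deepstack_visual_indexes)  # taps fall inside the 27 layers
print(grid_side, num_positions, tokens_per_image)  # 48 2304 576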

src/MaxText/layers/attentions.py

Lines changed: 28 additions & 8 deletions
@@ -63,6 +63,7 @@
 from MaxText.layers.embeddings import (
     LLaMARotaryEmbedding,
     LlamaVisionRotaryEmbedding,
+    Qwen3OmniMoeVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
 )
@@ -661,6 +662,14 @@ def convert_dense_general_inputs_shape(
     axis = canonicalize_tuple(axis)
     return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))

+  def get_vision_rotary_embedding_class(self):
+    """Gets the rotary embedding class based on the model type."""
+    if self.config.model_name.startswith("qwen3-omni"):
+      return Qwen3OmniMoeVisionRotaryEmbedding
+    elif self.config.model_name.startswith("llama4"):
+      return LlamaVisionRotaryEmbedding
+    raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
+
   def init_rotary_embedding(self):
     """Initializes the rotary embeddings, handling different model types.
@@ -676,15 +685,16 @@ def init_rotary_embedding(self):
     rope_type = self.config.rope_type.lower()
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
+      rotary_embedding_class = self.get_vision_rotary_embedding_class()
+      rotary_embedding = rotary_embedding_class(
           hidden_size=self.config.hidden_size_for_vit,
           num_attention_heads=self.config.num_attention_heads_for_vit,
+          spatial_merge_size=self.config.spatial_merge_size_for_vit,
           rope_theta=self.config.rope_theta_for_vit,
           fprop_dtype=self.dtype,
           rngs=self.rngs,
       )
+
     elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,
@@ -730,18 +740,27 @@
       )
     return rotary_embedding

-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict = None
+  ):
     """Applies rotary embeddings, handling different model types.

     Args:
       inputs: The input tensor to apply rotary embeddings to.
       inputs_positions: The positions of the inputs.
-      name: A name for the embedding layer.
+      rope_kwargs: A dictionary of keyword arguments for the rotary embedding.

     Returns:
       The input tensor with rotary embeddings applied.
     """
-    return self.rotary_embedding(inputs, inputs_positions)
+    if self.is_vision and self.config.model_name.startswith("qwen3-omni"):
+      # For Qwen3OmniMoe vision, pass static dimensions from kwargs
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      return self.rotary_embedding(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)

   def init_kv_caches(self, inputs_kv_shape: Tuple):
     """Initializes KVCache.
@@ -823,6 +842,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      rope_kwargs: dict = None,
   ):
     """Applies Attention on the input data.
@@ -886,8 +906,8 @@
     use_qk_norm = self.use_qk_norm and use_rope

     if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

     if use_qk_norm and is_llama4_decoder_block:
       l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
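
Taken together, the attention changes thread a rope_kwargs dictionary from __call__ into apply_rotary_embedding, which on the qwen3-omni vision path calls the rotary embedding as rotary_embedding(inputs, num_frames, height, width). The sketch below shows how a caller might assemble rope_kwargs for a single image; the grid values follow the model config above, and vision_attention is a hypothetical layer name, not something defined in this diff.

# Hypothetical caller-side sketch of the rope_kwargs plumbing added in this commit.
image_size, patch_size = 768, 16   # from the Qwen3-Omni model config above

rope_kwargs = {
    "num_frames": 1,                     # single image; videos get >1 after temporal patching
    "height": image_size // patch_size,  # 48
    "width": image_size // patch_size,   # 48
}

# An Attention layer built with is_vision=True would then receive it via __call__:
#   out = vision_attention(..., rope_kwargs=rope_kwargs)
# and apply_rotary_embedding forwards the values on the qwen3-omni branch shown above:
#   self.rotary_embedding(inputs, num_frames, height, width)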
