7 changes: 7 additions & 0 deletions src/MaxText/configs/base.yml
@@ -884,6 +884,13 @@ vision_output_dim_for_vit: 4096
pixel_shuffle_ratio_for_vit: 0.5
projector_dropout_for_vit: 0.0

# Qwen3-OmniMoe vision encoder
spatial_merge_size_for_vit: 2
out_hidden_size_for_vit: 512
temporal_patch_size_for_vit: 2
num_position_embeddings_for_vit: 1024
deepstack_visual_indexes_for_vit: []

# Subslice shape in the form of "x,y,z" when using pathways (single controller).
# Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
subslice_shape: ""
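The new `deepstack_visual_indexes_for_vit` key lists vision-encoder layers whose hidden states are tapped as extra visual features; the empty default keeps deepstack off unless a model config overrides it. A minimal sketch of the idea, assuming the Qwen-style deepstack scheme in which selected intermediate layer outputs are collected alongside the final output (`run_vit_with_deepstack` and its arguments are illustrative names, not MaxText API):

```python
# Illustrative sketch only, not the MaxText implementation.
def run_vit_with_deepstack(layers, hidden, deepstack_visual_indexes):
  """Runs encoder layers, tapping hidden states at the deepstack indexes."""
  deepstack_features = []
  for i, layer in enumerate(layers):
    hidden = layer(hidden)
    if i in deepstack_visual_indexes:
      deepstack_features.append(hidden)  # multi-level features for the LLM
  return hidden, deepstack_features
```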
20 changes: 19 additions & 1 deletion src/MaxText/configs/models/qwen3-omni-30b-a3b.yml
@@ -34,7 +34,25 @@ base_moe_mlp_dim: 768
norm_topk_prob: true

# RoPE Settings
-rope_max_timescale: 10_000_000
+rope_max_timescale: 1_000_000
max_position_embeddings: 65536

# General Model Settings
enable_dropout: False

# Vision Encoder Configuration
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
image_size_for_vit: 768
hidden_size_for_vit: 1152
intermediate_size_for_vit: 4304
num_attention_heads_for_vit: 16
num_hidden_layers_for_vit: 27
num_channels_for_vit: 3
patch_size_for_vit: 16
temporal_patch_size_for_vit: 2
spatial_merge_size_for_vit: 2
out_hidden_size_for_vit: 2048
num_position_embeddings_for_vit: 2304
deepstack_visual_indexes_for_vit: [7, 16, 24]

use_multimodal: true
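As a quick sanity check on the values above, assuming the usual relationship `num_position_embeddings = (image_size / patch_size)**2` from the Hugging Face Qwen3-Omni-MoE vision config, the numbers are self-consistent:

```python
# Sanity check of the config values above (assumed relationships).
image_size, patch_size, spatial_merge_size = 768, 16, 2

grid = image_size // patch_size                   # 48 patches per side
assert grid * grid == 2304                        # num_position_embeddings_for_vit
assert (grid // spatial_merge_size) ** 2 == 576   # tokens per image after 2x2 merge
```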
36 changes: 28 additions & 8 deletions src/MaxText/layers/attentions.py
@@ -63,6 +63,7 @@
from MaxText.layers.embeddings import (
LLaMARotaryEmbedding,
LlamaVisionRotaryEmbedding,
Qwen3OmniMoeVisionRotaryEmbedding,
RotaryEmbedding,
YarnRotaryEmbedding,
Qwen3NextRotaryEmbedding,
@@ -705,6 +706,14 @@ def convert_dense_general_inputs_shape(
axis = canonicalize_tuple(axis)
return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))

def get_vision_rotary_embedding_class(self):
"""Gets the rotary embedding class based on the model type."""
if self.config.model_name.startswith("qwen3-omni"):
return Qwen3OmniMoeVisionRotaryEmbedding
elif self.config.model_name.startswith("llama4"):
return LlamaVisionRotaryEmbedding
raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")

def init_rotary_embedding(self):
"""Initializes the rotary embeddings, handling different model types.

@@ -720,15 +729,16 @@ def init_rotary_embedding(self):
rope_type = self.config.rope_type.lower()
rope_use_scale = self.config.rope_use_scale
if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
+      rotary_embedding_class = self.get_vision_rotary_embedding_class()
+      rotary_embedding = rotary_embedding_class(
hidden_size=self.config.hidden_size_for_vit,
num_attention_heads=self.config.num_attention_heads_for_vit,
spatial_merge_size=self.config.spatial_merge_size_for_vit,
rope_theta=self.config.rope_theta_for_vit,
fprop_dtype=self.dtype,
rngs=self.rngs,
)

elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
rotary_embedding = LLaMARotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
@@ -784,18 +794,27 @@
)
return rotary_embedding

-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array] = None, rope_kwargs: Optional[dict] = None
+  ):
"""Applies rotary embeddings, handling different model types.

Args:
inputs: The input tensor to apply rotary embeddings to.
inputs_positions: The positions of the inputs.
      rope_kwargs: A dictionary of static dimensions (e.g. num_frames, height, width) for vision rotary embeddings.

Returns:
The input tensor with rotary embeddings applied.
"""
-    return self.rotary_embedding(inputs, inputs_positions)
+    if self.is_vision and self.config.model_name.startswith("qwen3-omni"):
+      # For Qwen3OmniMoe vision, pass the static grid dimensions from kwargs.
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      return self.rotary_embedding(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)
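On the Qwen3-Omni vision path the caller must now thread the static patch-grid shape through `rope_kwargs`. A hypothetical call for a single 768x768 image (a 48x48 patch grid at patch size 16); `attention_layer` and `patch_embeddings` are placeholders, and the argument names other than `rope_kwargs` follow the usual MaxText attention call and are assumptions here:

```python
# Hypothetical usage sketch; attention_layer / patch_embeddings are placeholders.
rope_kwargs = {"num_frames": 1, "height": 48, "width": 48}  # 768 / 16 = 48
out = attention_layer(
    inputs_q=patch_embeddings,
    inputs_kv=patch_embeddings,
    inputs_positions=None,  # unused on the qwen3-omni vision path
    rope_kwargs=rope_kwargs,
)
```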

def init_kv_caches(self, inputs_kv_shape: Tuple):
"""Initializes KVCache.
@@ -878,6 +897,7 @@ def __call__(
slot: Optional[int] = None,
page_state: Optional[page_manager.PageState] = None,
bidirectional_mask: Any = None,
      rope_kwargs: Optional[dict] = None,
):
"""Applies Attention on the input data.

@@ -952,8 +972,8 @@
use_qk_norm = self.use_qk_norm and use_rope

if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

if use_qk_norm and is_llama4_decoder_block:
l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)