
Commit 2515f9c

Add qwen3 omni vision encoder
1 parent 341061f commit 2515f9c

9 files changed: +4658, -2114 lines

src/MaxText/configs/base.yml

Lines changed: 7 additions & 0 deletions
@@ -884,6 +884,13 @@ vision_output_dim_for_vit: 4096
 pixel_shuffle_ratio_for_vit: 0.5
 projector_dropout_for_vit: 0.0

+# Qwen3-OmniMoe vision encoder
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 512
+temporal_patch_size_for_vit: 2
+num_position_embeddings_for_vit: 1024
+deepstack_visual_indexes_for_vit: []
+
 # Subslice shape in the form of "x,y,z" when using pathways (single controller).
 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
 subslice_shape: ""
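These are framework-wide defaults; the model config below overrides most of them. As a rough illustration of what `deepstack_visual_indexes_for_vit` is meant to select, assuming it behaves as in the Hugging Face Qwen3-OmniMoe reference (listed vision-encoder layers contribute intermediate "deepstack" features in addition to the final output), here is a hypothetical sketch with made-up names, not the MaxText implementation:

```python
# Illustrative sketch only (hypothetical helper, not MaxText code): gather hidden
# states from the vision-encoder layers listed in deepstack_visual_indexes_for_vit
# so they can be fused with the language model later, alongside the final output.
def run_vision_encoder(vision_layers, patch_embeddings, deepstack_visual_indexes):
  deepstack_features = []
  hidden = patch_embeddings
  for idx, layer in enumerate(vision_layers):
    hidden = layer(hidden)
    if idx in deepstack_visual_indexes:
      deepstack_features.append(hidden)  # e.g. layers 7, 16, 24 for qwen3-omni-30b-a3b
  return hidden, deepstack_features      # empty list under the base.yml default []
```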

src/MaxText/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 19 additions & 1 deletion
@@ -34,7 +34,25 @@ base_moe_mlp_dim: 768
 norm_topk_prob: true

 # RoPE Settings
-rope_max_timescale: 10_000_000
+rope_max_timescale: 1_000_000
+max_position_embeddings: 65536

 # General Model Settings
 enable_dropout: False
+
+# Vision Encoder Configuration
+# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
+image_size_for_vit: 768
+hidden_size_for_vit: 1152
+intermediate_size_for_vit: 4304
+num_attention_heads_for_vit: 16
+num_hidden_layers_for_vit: 27
+num_channels_for_vit: 3
+patch_size_for_vit: 16
+temporal_patch_size_for_vit: 2
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 2048
+num_position_embeddings_for_vit: 2304
+deepstack_visual_indexes_for_vit: [7, 16, 24]
+
+use_multimodal: true
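The vision geometry here is internally consistent: a 768x768 image cut into 16-pixel patches gives a 48x48 grid, i.e. 2304 patch positions, which matches num_position_embeddings_for_vit. A small sanity-check sketch in plain Python; the assumption that spatial_merge_size_for_vit merges 2x2 neighbouring patches before the out_hidden_size_for_vit projection follows the Hugging Face reference linked above:

```python
# Sanity check of the ViT geometry implied by qwen3-omni-30b-a3b.yml.
image_size = 768
patch_size = 16
spatial_merge_size = 2
num_position_embeddings = 2304

patches_per_side = image_size // patch_size   # 768 / 16 = 48
num_patches = patches_per_side**2             # 48 * 48 = 2304 positions
assert num_patches == num_position_embeddings

# Assuming a Qwen-style 2x2 spatial merge, the token count handed to the
# out_hidden_size_for_vit (2048) projection per image would be:
merged_tokens = (patches_per_side // spatial_merge_size) ** 2  # 24 * 24 = 576
print(num_patches, merged_tokens)
```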

src/MaxText/layers/attentions.py

Lines changed: 28 additions & 8 deletions
@@ -63,6 +63,7 @@
 from MaxText.layers.embeddings import (
     LLaMARotaryEmbedding,
     LlamaVisionRotaryEmbedding,
+    Qwen3OmniMoeVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
     Qwen3NextRotaryEmbedding,

@@ -705,6 +706,14 @@ def convert_dense_general_inputs_shape(
     axis = canonicalize_tuple(axis)
     return tuple(inputs_shape[ax] for ax in normalize_axes(axis, len(inputs_shape)))

+  def get_vision_rotary_embedding_class(self):
+    """Gets the rotary embedding class based on the model type."""
+    if self.config.model_name.startswith("qwen3-omni"):
+      return Qwen3OmniMoeVisionRotaryEmbedding
+    elif self.config.model_name.startswith("llama4"):
+      return LlamaVisionRotaryEmbedding
+    raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
+
   def init_rotary_embedding(self):
     """Initializes the rotary embeddings, handling different model types.

@@ -720,15 +729,16 @@ def init_rotary_embedding(self):
     rope_type = self.config.rope_type.lower()
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
+      rotary_embbeding_class = self.get_vision_rotary_embedding_class()
+      rotary_embedding = rotary_embbeding_class(
          hidden_size=self.config.hidden_size_for_vit,
          num_attention_heads=self.config.num_attention_heads_for_vit,
+          spatial_merge_size=self.config.spatial_merge_size_for_vit,
          rope_theta=self.config.rope_theta_for_vit,
          fprop_dtype=self.dtype,
          rngs=self.rngs,
      )
+
     elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,

@@ -784,18 +794,27 @@ def init_rotary_embedding(self):
       )
     return rotary_embedding

-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict = None
+  ):
     """Applies rotary embeddings, handling different model types.

     Args:
       inputs: The input tensor to apply rotary embeddings to.
       inputs_positions: The positions of the inputs.
-      name: A name for the embedding layer.
+      rope_kwargs: A dictionary of keyword arguments for the rotary embedding.

     Returns:
       The input tensor with rotary embeddings applied.
     """
-    return self.rotary_embedding(inputs, inputs_positions)
+    if self.is_vision and self.config.model_name.startswith("qwen3-omni"):
+      # For Qwen3OmniMoe vision, pass static dimensions from kwargs
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      return self.rotary_embedding(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)

   def init_kv_caches(self, inputs_kv_shape: Tuple):
     """Initializes KVCache.

@@ -878,6 +897,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      rope_kwargs: dict = None,
   ):
     """Applies Attention on the input data.

@@ -952,8 +972,8 @@
     use_qk_norm = self.use_qk_norm and use_rope

     if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

     if use_qk_norm and is_llama4_decoder_block:
       l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
src/MaxText/layers/decoders.py

Lines changed: 86 additions & 16 deletions
@@ -97,11 +97,23 @@ def __call__(
     )

     if self.model_mode == MODEL_MODE_PREFILL:
-      logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
+      logical_axis_names = (
+          "activation_batch",
+          "prefill_activation_length",
+          "activation_embed",
+      )
     elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN:
-      logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed")
+      logical_axis_names = (
+          "activation_batch_no_exp",
+          "activation_length",
+          "activation_embed",
+      )
     else:
-      logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed")
+      logical_axis_names = (
+          "activation_batch",
+          "activation_length_no_exp",
+          "activation_embed",
+      )

     if model_mode == MODEL_MODE_PREFILL:
       inputs = _maybe_shard_with_logical(inputs, logical_axis_names)

@@ -235,7 +247,11 @@ def __call__(
   ) -> jnp.ndarray:
     for lyr in range(self.num_decoder_layers):
       inputs = self.decoder_layer(
-          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
+          config=self.config,
+          mesh=self.mesh,
+          name=f"layers_{lyr}",
+          quant=self.quant,
+          model_mode=model_mode,
       )(
           inputs,
           decoder_segment_ids,

@@ -269,7 +285,10 @@ def setup(self):
     pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
     remat_policy = self.get_remat_policy()
     self.pipeline_module = pipeline.Pipeline(
-        config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
+        config=self.config,
+        mesh=self.mesh,
+        layers=pipeline_stage_module,
+        remat_policy=remat_policy,
     )

   def minimal_policy(self, with_context=False):

@@ -339,7 +358,11 @@ def get_remat_policy(self):
     elif cfg.remat_policy == "qkv_proj_offloaded":
       policy = jax.checkpoint_policies.save_and_offload_only_these_names(
           names_which_can_be_saved=[],
-          names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"],
+          names_which_can_be_offloaded=[
+              "query_proj",
+              "value_proj",
+              "key_proj",
+          ],
           offload_src="device",
           offload_dst="pinned_host",
       )

@@ -395,7 +418,10 @@ def get_decoder_layers(self):
         return [mixtral.MixtralDecoderLayerToLinen]
       case DecoderBlockType.DEEPSEEK:
         if self.config.use_batch_split_schedule:
-          return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
+          return [
+              deepseek_batchsplit.DeepSeekDenseLayer,
+              deepseek_batchsplit.DeepSeekMoELayer,
+          ]
         else:
           return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
       case DecoderBlockType.GEMMA:

@@ -447,7 +473,10 @@ def map_fn(path, value):
           block_layer,
           prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config),
           policy=policy,
-          static_argnums=(4, 5),  # Deterministic and model mode are static arguments.
+          static_argnums=(
+              4,
+              5,
+          ),  # Deterministic and model mode are static arguments.
       )
       RemattedBlockLayers.append(layer)
     return RemattedBlockLayers

@@ -473,11 +502,25 @@ def get_norm_layer(self, num_features: int):
     ):
       return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
     elif self.config.decoder_block == DecoderBlockType.GPT3:
-      return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)
+      return functools.partial(
+          gpt3.gpt3_layer_norm,
+          num_features=num_features,
+          reductions_in_fp32=False,
+          use_bias=True,
+      )
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

-  def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs):
+  def scan_decoder_layers(
+      self,
+      cfg,
+      decoder_layer,
+      length,
+      metadata_axis_name,
+      mesh,
+      in_axes_tuple,
+      **kwargs,
+  ):
     """scan decoder layers, calls `flax.linen.transforms.scan`"""
     initializing = self.is_mutable_collection("params")
     params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)

@@ -500,7 +543,11 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
         metadata_params={nn.PARTITION_NAME: metadata_axis_name},
     )
     return scan_fn(
-        config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs  # pytype: disable=wrong-keyword-args
+        config=cfg,
+        mesh=mesh,
+        name=metadata_axis_name,
+        quant=self.quant,
+        **kwargs,  # pytype: disable=wrong-keyword-args
     )

   def get_pipeline_stage_module(self, decoder_blocks):

@@ -558,7 +605,13 @@ def _apply_embedding(

     # Merge the image embeddings with the text embeddings for multimodal models
     if image_embeddings is not None and cfg.use_multimodal:
-      if cfg.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e"]:
+      if cfg.model_name in [
+          "gemma3-4b",
+          "gemma3-12b",
+          "gemma3-27b",
+          "llama4-17b-16e",
+          "llama4-17b-128e",
+      ]:
         y = multimodal_utils.merge_mm_embeddings(
             text_embeddings=y,
             vision_embeddings=image_embeddings,

@@ -751,7 +804,10 @@ def __call__(
       remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
       if remaining_layers > 0:
         logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
-        with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
+        with (
+            self.mesh,
+            nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp),
+        ):
           y, _ = self.scan_decoder_layers(
               cfg,
               RemattedBlockLayers[0],

@@ -838,7 +894,11 @@ def __call__(
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
           y = layer(
-              config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
+              config=cfg,
+              mesh=mesh,
+              name=f"{layer_prefix}_{index}",
+              quant=self.quant,
+              model_mode=self.model_mode,
          )(
              y,
              decoder_segment_ids,

@@ -868,7 +928,12 @@ def __call__(
       if cfg.decoder_block == DecoderBlockType.GPT_OSS:
         layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
       layer = RemattedBlockLayer(
-          config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
+          config=cfg,
+          mesh=mesh,
+          name=f"layers_{lyr}",
+          quant=self.quant,
+          model_mode=self.model_mode,
+          **layer_kwargs,
      )
      y = layer(
          y,

@@ -952,7 +1017,12 @@ def _apply_gemma3_scanned_blocks(
       rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
       # pytype: disable=wrong-keyword-args
       layer = RemattedGemma3Block(
-          config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
+          config=cfg,
+          mesh=mesh,
+          quant=self.quant,
+          model_mode=self.model_mode,
+          name="layers_remainder",
+          **rem_layer_kwargs,
      )
      y, _ = layer(
          y,
