src/MaxText/configs/base.yml (+27, -0)
@@ -852,6 +852,33 @@ rope_theta_for_vit: 10000
vision_output_dim_for_vit: 4096
pixel_shuffle_ratio_for_vit: 0.5
projector_dropout_for_vit: 0.0
# Qwen3-OmniMoe vision encoder specific configs
spatial_merge_size_for_vit: 2
out_hidden_size_for_vit: 512
hidden_act_for_vit: "gelu"
temporal_patch_size_for_vit: 2
num_position_embeddings_for_vit: 1024
deepstack_visual_indexes_for_vit: []

### Audio encoder configs (Qwen3-OmniMoe)
d_model_for_audio: 256
encoder_attention_heads_for_audio: 4
encoder_ffn_dim_for_audio: 512
encoder_layers_for_audio: 2
attention_dropout_for_audio: 0.0
activation_dropout_for_audio: 0.0
activation_function_for_audio: "gelu"
num_mel_bins_for_audio: 128
max_source_positions_for_audio: 1500
scale_embedding_for_audio: True
n_window_for_audio: 50
n_window_infer_for_audio: 800
conv_chunksize_for_audio: 500
downsample_hidden_size_for_audio: 256
output_dim_for_audio: 512
num_conv_layers_for_audio: 3
max_timescale_for_audio: 10000.0
max_sample_len_for_audio: 10000

# Subslice shape in the form of "x,y,z" when using pathways (single controller).
# Example: "8,8" to use an 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
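The audio defaults above (sequence length, model width, sinusoid timescale) feed the sinusoidal position table added in embeddings.py later in this diff. A minimal sketch of that wiring, assuming the audio encoder maps these config keys to the constructor arguments (the mapping is an illustration, not code from this diff):

from MaxText.layers.embeddings import SinusoidsPositionEmbedding

pos_table = SinusoidsPositionEmbedding(
    length=1500,            # max_source_positions_for_audio
    channels=256,           # d_model_for_audio
    max_timescale=10000.0,  # max_timescale_for_audio
)
audio_pos = pos_table(300)  # (300, 256) table slice for a 300-frame input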
src/MaxText/configs/models/qwen3-omni-30b-a3b.yml (+56, -0)
@@ -0,0 +1,56 @@
# Core Architectural Parameters
decoder_block: "qwen3_moe"
base_emb_dim: 2048
base_mlp_dim: 768
base_num_query_heads: 32
base_num_kv_heads: 4
base_num_decoder_layers: 48
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 151936
normalization_layer_epsilon: 1.0e-6
use_qk_norm: True

# MoE Specific Parameters
num_experts: 128
num_experts_per_tok: 8
base_moe_mlp_dim: 768
norm_topk_prob: True

# RoPE Settings
rope_max_timescale: 10_000_000

# General Model Settings
enable_dropout: False

# Audio Encoder Configuration (set use_audio=true to enable)
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
d_model_for_audio: 1280
encoder_layers_for_audio: 32
encoder_attention_heads_for_audio: 20
encoder_ffn_dim_for_audio: 5120
max_source_positions_for_audio: 1500
num_mel_bins_for_audio: 128
downsample_hidden_size_for_audio: 480
output_dim_for_audio: 2048
attention_dropout_for_audio: 0.0
n_window_for_audio: 50
n_window_infer_for_audio: 400
conv_chunksize_for_audio: 500
num_conv_layers_for_audio: 3
max_timescale_for_audio: 10000.0
max_sample_len_for_audio: 10000

# Vision Encoder Configuration (set use_images=true to enable)
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
hidden_size_for_vit: 1152
intermediate_size_for_vit: 4304
num_attention_heads_for_vit: 16
num_hidden_layers_for_vit: 27
num_channels_for_vit: 3
patch_size_for_vit: 16
temporal_patch_size_for_vit: 2
spatial_merge_size_for_vit: 2
out_hidden_size_for_vit: 2048
num_position_embeddings_for_vit: 2304
deepstack_visual_indexes_for_vit: [8, 16, 24]
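One non-obvious key here is deepstack_visual_indexes_for_vit. Following the HF reference linked above, these indexes select intermediate ViT layers whose hidden states are captured alongside the final output and injected into the language model as extra multi-scale features. A toy sketch of that capture loop (an illustration of the pattern, not MaxText's implementation):

import jax.numpy as jnp

num_hidden_layers = 27                  # num_hidden_layers_for_vit
deepstack_visual_indexes = [8, 16, 24]  # deepstack_visual_indexes_for_vit

def toy_block(i, x):
  # Stand-in for a ViT transformer block (attention + MLP in the real model).
  return x + jnp.float32(i)

x = jnp.zeros((4, 1152))  # (tokens, hidden_size_for_vit)
deepstack_features = []
for i in range(num_hidden_layers):
  x = toy_block(i, x)
  if i in deepstack_visual_indexes:
    deepstack_features.append(x)  # later fed to matching decoder layers

assert len(deepstack_features) == 3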
src/MaxText/layers/embeddings.py (+87, -0)
@@ -741,6 +741,37 @@ def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int =
)


def sinusoids_position_embedding_as_linen(
    *,
    length: int,
    channels: int,
    max_timescale: float = _MAX_WAVELENGTH,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
):
  """Initializes the SinusoidsPositionEmbedding module and returns it as a Linen module.

  Args:
    length: Maximum sequence length.
    channels: Number of embedding channels (must be even).
    max_timescale: Maximum timescale for the sinusoidal frequencies.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    name: Name of the Linen module.

  Returns:
    A Linen module wrapping SinusoidsPositionEmbedding.
  """
  return nnx_wrappers.to_linen(
      SinusoidsPositionEmbedding,
      length=length,
      channels=channels,
      max_timescale=max_timescale,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )
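A usage sketch for this wrapper, assuming the standard init/apply flow that nnx_wrappers.to_linen-wrapped modules expose (the shapes follow the class below):

import jax

pos_module = sinusoids_position_embedding_as_linen(length=1500, channels=1280)
variables = pos_module.init(jax.random.PRNGKey(0), 100)
embeddings = pos_module.apply(variables, 100)  # (100, 1280), bfloat16 by default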


@dataclasses.dataclass(repr=False)
class PositionalEmbedding(nnx.Module):
  """A layer that adds sinusoidal positional embeddings to the input.
@@ -918,3 +949,59 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
    output = output.astype(self.fprop_dtype)

    return output


@dataclasses.dataclass(repr=False)
class SinusoidsPositionEmbedding(nnx.Module):
  """Sinusoidal position embeddings with a precomputed table for efficient lookup."""

  def __init__(
      self,
      length: int,
      channels: int,
      max_timescale: float = _MAX_WAVELENGTH,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      *,
      rngs: nnx.Rngs | None = None,  # unused here; kept so nnx_wrappers.to_linen can construct the module
  ):
    """Precomputes sinusoidal position embeddings for all `length` positions."""
    if channels % 2 != 0:
      raise ValueError("SinusoidsPositionEmbedding requires an even number of channels")

    self.length = length
    self.channels = channels
    self.max_timescale = max_timescale
    self.cast_as_fprop_dtype = cast_as_fprop_dtype
    self.fprop_dtype = fprop_dtype

    # Standard transformer sinusoids: the channels // 2 timescales form a
    # geometric progression from 1 to max_timescale.
    log_timescale_increment = jnp.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2, dtype=jnp.float32))
    scaled_time = jnp.arange(length, dtype=jnp.float32)[:, None] * inv_timescales[None, :]
    positional_embedding = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)

    self.positional_embedding = nnx.Variable(positional_embedding)

  def __call__(self, seqlen: int) -> Array:
    """Returns positional embeddings for the given sequence length.

    Args:
      seqlen: Sequence length to retrieve embeddings for (must be <= self.length).

    Returns:
      Positional embeddings of shape (seqlen, channels).
    """
    output = jax.lax.dynamic_slice(
        self.positional_embedding.value,
        start_indices=(0, 0),
        slice_sizes=(seqlen, self.channels),
    )
    if self.cast_as_fprop_dtype:
      output = output.astype(self.fprop_dtype)
    return output
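Because the first and the (channels // 2)-th columns use timescale 1, they are plain sin(t) and cos(t), which gives a quick sanity check of the table construction (a minimal sketch run alongside the class above):

import jax.numpy as jnp

emb = SinusoidsPositionEmbedding(length=16, channels=8, max_timescale=10000.0)
out = emb(4)
assert out.shape == (4, 8)
t = jnp.arange(4, dtype=jnp.float32)
# Output is bfloat16 by default, hence the loose tolerance.
assert jnp.allclose(out[:, 0].astype(jnp.float32), jnp.sin(t), atol=1e-2)
assert jnp.allclose(out[:, 4].astype(jnp.float32), jnp.cos(t), atol=1e-2)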