188 changes: 4 additions & 184 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
@@ -3,28 +3,22 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
Qwen3OmniMoeTalkerConfig,
)
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
Qwen3OmniMoeAudioEncoder,
)
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsPP,
)
from vllm.model_executor.models.qwen2_5_omni_thinker import (
Qwen2_5OmniThinkerDummyInputsBuilder,
)
from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock
from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
@@ -512,162 +506,14 @@ def forward(self, hidden_state):
return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))


class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module):
"""
Wrapper that combines shared_expert MLP with its sigmoid gate.

This matches the HuggingFace weight structure where:
- mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
- mlp.shared_expert_gate.weight (sibling, not child)

The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x).

It also exposes the underlying shared_expert interface to keep
compatibility with backends that split shared-expert computation.
"""

def __init__(
self,
shared_expert: Qwen3MoeMLP,
shared_expert_gate: nn.Linear,
):
super().__init__()
self._shared_expert = shared_expert
self._shared_expert_gate = shared_expert_gate

@property
def gate_up_proj(self):
return self._shared_expert.gate_up_proj

@property
def down_proj(self):
return self._shared_expert.down_proj

@property
def act_fn(self):
return self._shared_expert.act_fn

def expert_gate(self, x: torch.Tensor):
gate_out = self._shared_expert_gate(x)
if isinstance(gate_out, tuple):
return gate_out
return gate_out, None

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self._shared_expert(x)
gate_out = self._shared_expert_gate(x)
if isinstance(gate_out, tuple):
gate_out = gate_out[0]
gate_values = F.sigmoid(gate_out) # [batch, 1]
return gate_values * out # Broadcasting: [batch, 1] * [batch, hidden]
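To make the gating above concrete, the following standalone sketch (not part of the diffed file) reproduces the sigmoid-gate broadcasting with plain torch modules: an nn.Sequential stands in for the gated Qwen3MoeMLP and nn.Linear stands in for the vLLM parallel layers, and the sizes are illustrative only.

import torch
import torch.nn as nn

# Standalone sketch of sigmoid(shared_expert_gate(x)) * shared_expert(x).
# Plain modules and sizes are illustrative stand-ins, not the real config.
hidden_size, intermediate_size, batch = 8, 16, 4

shared_expert = nn.Sequential(                    # stand-in for Qwen3MoeMLP
    nn.Linear(hidden_size, intermediate_size, bias=False),
    nn.SiLU(),
    nn.Linear(intermediate_size, hidden_size, bias=False),
)
shared_expert_gate = nn.Linear(hidden_size, 1, bias=False)

x = torch.randn(batch, hidden_size)
expert_out = shared_expert(x)                     # [batch, hidden]
gate = torch.sigmoid(shared_expert_gate(x))       # [batch, 1]
out = gate * expert_out                           # broadcasts to [batch, hidden]
print(out.shape)                                  # torch.Size([4, 8])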


class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module):
"""
Sparse MoE block for Qwen3 Omni MoE Talker with shared expert support.

This block uses SharedFusedMoE to efficiently compute both routed experts
and the shared expert, potentially overlapping computation with communication.

Weight structure matches HuggingFace:
- mlp.gate.weight (router)
- mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
- mlp.shared_expert_gate.weight
- mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight
"""

def __init__(
self,
config: Qwen3OmniMoeTalkerConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
text_config = config.text_config
self.tp_size = get_tensor_model_parallel_world_size()

if self.tp_size > text_config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than the number of experts {text_config.num_experts}."
)

# Router gate for selecting top-k experts
self.gate = ReplicatedLinear(
text_config.hidden_size,
text_config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate",
)

# Shared expert MLP (matches HF: mlp.shared_expert.*)
if text_config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen3MoeMLP(
hidden_size=text_config.hidden_size,
intermediate_size=text_config.shared_expert_intermediate_size,
hidden_act=text_config.hidden_act,
quant_config=quant_config,
reduce_results=False, # Don't reduce, we'll handle it
prefix=f"{prefix}.shared_expert",
)
# Shared expert gate (matches HF: mlp.shared_expert_gate.weight)
# This is a sibling of shared_expert, not a child
self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False)
# Create wrapper for SharedFusedMoE
self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper(
self.shared_expert, self.shared_expert_gate
)
else:
self.shared_expert = None
self.shared_expert_gate = None
self._shared_expert_wrapper = None

# Fused MoE with shared expert support
self.experts = SharedFusedMoE(
shared_experts=self._shared_expert_wrapper,
num_experts=text_config.num_experts,
top_k=text_config.num_experts_per_tok,
hidden_size=text_config.hidden_size,
intermediate_size=text_config.moe_intermediate_size,
reduce_results=False, # We'll reduce manually after combining
renormalize=text_config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)

# Compute router logits
router_logits, _ = self.gate(hidden_states)

# Forward through SharedFusedMoE
# Returns (shared_out, fused_out) when shared_expert is present
final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)

# Combine shared and routed expert outputs
if self._shared_expert_wrapper is not None:
# SharedFusedMoE returns tuple: (shared_out, fused_out)
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]

# Apply tensor parallel reduction if needed
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)

return final_hidden_states.view(orig_shape)
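For reference, here is a standalone, unfused sketch (not part of the diffed file) of what this block computes end to end: softmax the router logits, route each token to its top-k experts with renormalized weights (the norm_topk_prob case), add the sigmoid-gated shared expert, and sum. Plain nn.Linear MLPs stand in for SharedFusedMoE and the vLLM parallel layers; the sizes are illustrative, not taken from Qwen3OmniMoeTalkerConfig.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Reference semantics of routed top-k experts + sigmoid-gated shared expert.
# All modules and sizes are illustrative stand-ins for the fused kernel.
hidden, num_experts, top_k, batch = 8, 4, 2, 3

def make_mlp(inter):
    return nn.Sequential(nn.Linear(hidden, inter, bias=False),
                         nn.SiLU(),
                         nn.Linear(inter, hidden, bias=False))

gate = nn.Linear(hidden, num_experts, bias=False)            # router
experts = nn.ModuleList(make_mlp(16) for _ in range(num_experts))
shared_expert = make_mlp(32)
shared_expert_gate = nn.Linear(hidden, 1, bias=False)

x = torch.randn(batch, hidden)
logits = gate(x)                                             # [batch, num_experts]
weights, idx = torch.topk(F.softmax(logits, dim=-1), top_k)  # top-k routing
weights = weights / weights.sum(dim=-1, keepdim=True)        # renormalize

routed = torch.zeros_like(x)
for b in range(batch):
    for w, e in zip(weights[b], idx[b]):
        routed[b] += w * experts[int(e)](x[b])               # weighted expert sum

shared = torch.sigmoid(shared_expert_gate(x)) * shared_expert(x)
out = shared + routed                                        # what forward() returns (before any all-reduce)
print(out.shape)                                             # torch.Size([3, 8])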


class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM):
"""
Qwen3 Omni MoE Talker language model.

This model extends Qwen3MoeLLMForCausalLM with:
- Shared expert support via SharedFusedMoE
- Codec embedding instead of text embedding
- No LM head (codec head is separate in the parent class)
Extends Qwen3MoeLLMForCausalLM (which already uses SharedFusedMoE with
shared-expert support) and replaces the text embedding / LM head with a
codec embedding so the talker operates over audio-codec tokens instead
of text tokens.
"""

def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str):
@@ -699,32 +545,6 @@ def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerCon
talker_config.text_config.hidden_size,
)

# Replace MoE blocks with shared expert versions
self._replace_moe_blocks_with_shared_expert(prefix)

def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None:
"""
Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock
that includes shared expert support via SharedFusedMoE.
"""
# Get compilation config to clean up registered layer names
compilation_config = self.talker_vllm_config.compilation_config

for layer_idx, layer in enumerate(self.model.layers):
# Check if this layer has a MoE block (has experts attribute)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
# Remove old layer registration from static_forward_context
old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts"
if old_experts_prefix in compilation_config.static_forward_context:
del compilation_config.static_forward_context[old_experts_prefix]

# Create new MoE block with shared expert support
layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock(
config=self.config,
quant_config=self.talker_vllm_config.quant_config,
prefix=f"{prefix}.model.layers.{layer_idx}.mlp",
)
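As a hedged illustration of the replacement pattern being removed here: walk the decoder layers, drop the stale prefix from a dict-style registry, and swap the mlp submodule in place. DecoderLayerStub, OldBlock, NewBlock, and registry below are hypothetical stand-ins for the decoder layer, Qwen3MoeSparseMoeBlock, Qwen3OmniMoeTalkerSparseMoeBlock, and static_forward_context; this is a sketch of the pattern, not the vLLM code.

import torch.nn as nn

class OldBlock(nn.Module):            # hypothetical stand-in for the original MoE block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class NewBlock(nn.Module):            # hypothetical stand-in for the replacement block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class DecoderLayerStub(nn.Module):    # hypothetical stand-in for a decoder layer
    def __init__(self):
        super().__init__()
        self.mlp = OldBlock()

layers = nn.ModuleList(DecoderLayerStub() for _ in range(2))
registry = {f"model.layers.{i}.mlp.experts": object() for i in range(2)}

for layer_idx, layer in enumerate(layers):
    if isinstance(layer.mlp, OldBlock):
        registry.pop(f"model.layers.{layer_idx}.mlp.experts", None)  # deregister stale prefix
        layer.mlp = NewBlock()                                       # swap the MoE block in place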

def embed_input_ids(
self,
input_ids: torch.Tensor,