
Add support for Qwen3-VL model via Torchax path#1974

Open
muskansh-google wants to merge 23 commits into vllm-project:main from muskansh-google:pr-1952

Conversation

@muskansh-google

Description

This PR enables Qwen3-VL inference on uLLM using the Torchax framework. By default, vLLM treats a model's vision encoder as part of the input-preparation phase rather than the core compiled model. That limitation makes JIT-compiling the vision encoder extremely difficult, so for the time being this PR executes the vision tower eagerly on TPUs.

Why this change is being made

Currently, uLLM does not support the Qwen3-VL model. Using Torchax as the backend lets us leverage the upstream vLLM model implementation. The PyTorch-based Qwen3-VL uses dynamic operations and in-place state mutations (specifically for Deepstack features) that conflict with JAX's requirement for statically traceable graphs. This PR provides monkey-patches and utility wrappers to bridge the gap between vLLM's PyTorch implementation and JAX execution.

Solved Problems & Relevance

1. tpu_inference/models/vllm/vllm_model_wrapper.py

  • ViT Attention Optimization: Registered a custom function for torch_sdpa through Torchax to utilize sharded_flash_attention, significantly improving vision encoder performance on TPUs.
  • Deepstack Stateless Patching: Overrode _set_deepstack_input_embeds and _get_deepstack_input_embeds to use a stateless dictionary cache (_deepstack_tensors) instead of in-place model mutations.
  • JIT Side-Channel (State Passing): Packs intermediate Deepstack embeddings into inputs_embeds to pass them safely through JIT boundaries.
  • Dynamic Argument Wrapping: Added wrap_embed_multimodal_func and wrap_embed_input_ids_func to handle dynamic kwargs and shape conversions (mm_embeds had to be passed as a list of tensors for Qwen3-VL).
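The Deepstack stateless patching described above can be sketched as follows. Names and the layer-index keying are illustrative, mirroring the PR description rather than the exact implementation in `vllm_model_wrapper.py`:

```python
class DeepstackCacheMixin:
    """Illustrative sketch: intermediate Deepstack embeddings are kept in a
    plain dict (_deepstack_tensors) that can be cleared between forward
    passes, instead of mutating model buffers in place (which JAX tracing
    cannot capture)."""

    def __init__(self):
        self._deepstack_tensors = {}

    def _set_deepstack_input_embeds(self, layer_idx, embeds):
        # Store by layer index; no in-place mutation of model state.
        self._deepstack_tensors[layer_idx] = embeds

    def _get_deepstack_input_embeds(self, layer_idx):
        return self._deepstack_tensors.get(layer_idx)

    def _reset_deepstack(self):
        # Drop everything so a new forward pass starts clean.
        self._deepstack_tensors = {}


model = DeepstackCacheMixin()
model._set_deepstack_input_embeds(0, [1.0, 2.0])
assert model._get_deepstack_input_embeds(0) == [1.0, 2.0]
model._reset_deepstack()
assert model._get_deepstack_input_embeds(0) is None
```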

Shortcomings and Future Improvements

  • The inputs_embeds side-channel is a workaround for JAX/JIT signature limitations.
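The side-channel can be sketched as a pack/unpack pair: extra per-layer rows are appended to the embeddings before the JIT boundary and split off again inside the compiled function. Plain lists stand in for tensors here; shapes and names are illustrative:

```python
def pack(inputs_embeds, deepstack_embeds):
    # Append the Deepstack rows after the real embeddings and record how
    # many were added, so unpack() can split at exactly the right point.
    return inputs_embeds + deepstack_embeds, len(deepstack_embeds)

def unpack(packed, n_extra):
    # Split the packed embeddings back into (core, side-channel) parts.
    if n_extra == 0:
        return packed, []
    return packed[:-n_extra], packed[-n_extra:]

embeds = [[0.1, 0.2], [0.3, 0.4]]
extra = [[9.0, 9.0]]
packed, n = pack(embeds, extra)
core, side = unpack(packed, n)
assert core == embeds and side == extra
```

Because only `inputs_embeds` crosses the JIT boundary, the function signature stays fixed even when Deepstack features are present.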

Tests

The changes were verified using the examples/multi_modal_inference.py script on a TPU v6e VM.

Reproduction Commands

Standard single image inference:

python3 -m examples.multi_modal_inference \
  --model Qwen/Qwen3-VL-8B-Instruct \
  --gpu-memory-utilization 0.9 \
  --enable-chunked-prefill False

yuyanpeng-google and others added 9 commits March 23, 2026 15:15
1. add vllm wrapper for multimodal.
2. modify interface of embed_input_ids with related jax model.
3. modify gather_mm_embeddings to get is_mm_embed for the new interface.
4. register function for torch.sdpa through torchax to use flash attention

Signed-off-by: Yuyan Peng <yuyanpeng@google.com>
from vllm.config import VllmConfig, set_current_vllm_config
from torchax.ops.ops_registry import (register_torch_dispatch_op,
register_torch_function_op)
from vllm.config import VllmConfig, set_current_vllm_config, set_current_vllm_config

There are two `set_current_vllm_config` imports here.


model_example_map = {
"qwen2_5_vl": run_qwen2_5_vl,
"Qwen/Qwen2.5-VL-3B-Instruct": run_qwen_vl,

I think specifying Qwen2.5-VL-3B-Instruct might not be general enough in case we want to test Qwen2.5-VL-7B. Could you make this more general?

model_key = args.model


req_data = model_example_map[model_key](questions, modality, args)

Could you add a check that model_key is in the map and raise an error if not?
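A minimal sketch of that check, reusing the names from the quoted snippet (the placeholder runner and the choice of `ValueError` are assumptions):

```python
model_example_map = {
    # Placeholder runner standing in for run_qwen2_5_vl.
    "qwen2_5_vl": lambda questions, modality, args: None,
}

def get_runner(model_key):
    # Fail fast with a clear message instead of an opaque KeyError.
    if model_key not in model_example_map:
        raise ValueError(
            f"Unsupported model {model_key!r}; expected one of "
            f"{sorted(model_example_map)}")
    return model_example_map[model_key]
```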

else:
seg_ids = None

from tpu_inference.layers.common.attention_interface import \

This is already imported at the top of this file.

vllm_model._deepstack_tensors = {}

if isinstance(deepstack_input_embeds, dict):
vllm_model._deepstack_tensors.update(deepstack_input_embeds)

One concern here: do we need to reset _deepstack_tensors to {} when deepstack_input_embeds is None? Otherwise it may retain values from the previous forward pass.
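One way to address this, sketched with hypothetical names (a `SimpleNamespace` stands in for the wrapped model):

```python
from types import SimpleNamespace

def update_deepstack_cache(vllm_model, deepstack_input_embeds):
    # Reset the cache when no new embeds arrive, so stale tensors from the
    # previous forward pass are never reused.
    if deepstack_input_embeds is None:
        vllm_model._deepstack_tensors = {}
    elif isinstance(deepstack_input_embeds, dict):
        vllm_model._deepstack_tensors.update(deepstack_input_embeds)

m = SimpleNamespace(_deepstack_tensors={0: "stale"})
update_deepstack_cache(m, None)
assert m._deepstack_tensors == {}
update_deepstack_cache(m, {1: "fresh"})
assert m._deepstack_tensors == {1: "fresh"}
```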

)
# Qwen3-VL uses a different method signature and takes in mm_features as an argument.
import inspect
takes_mm_features = "mm_features" in inspect.signature(get_mrope_input_positions_fn).parameters

Could you use supports_kw here instead of inspect.signature?
Also, I think this can be moved outside the for loop: https://github.com/muskansh-google/tpu-inference/blob/dd49ffcca0d74b577ea88512bd45742ee04a67ea/tpu_inference/runner/persistent_batch_manager.py#L132.
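For context, the quoted check just asks whether the callable accepts an `mm_features` keyword; vLLM's `supports_kw` helper covers the same idea. A standalone `inspect`-based illustration (helper name is made up here):

```python
import inspect

def takes_kw(fn, name):
    # True if `fn` accepts `name` as an explicit parameter or via **kwargs.
    params = inspect.signature(fn).parameters
    if name in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD
               for p in params.values())

def old_style(input_tokens):
    return input_tokens

def new_style(input_tokens, mm_features=None):
    return input_tokens

assert not takes_kw(old_style, "mm_features")
assert takes_kw(new_style, "mm_features")
```

Since the target function's signature does not change per request, the result can indeed be computed once outside the loop.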
