Add support for Qwen3-VL model via Torchax path#1974
muskansh-google wants to merge 23 commits into vllm-project:main
Conversation
1. Add a vLLM wrapper for multimodal models. 2. Modify the interface of embed_input_ids with the related JAX model. 3. Modify gather_mm_embeddings to get is_mm_embed for the new interface. 4. Register a function for torch.sdpa through torchax to use flash attention. Signed-off-by: Yuyan Peng <yuyanpeng@google.com>
| from vllm.config import VllmConfig, set_current_vllm_config
| from torchax.ops.ops_registry import (register_torch_dispatch_op,
|                                       register_torch_function_op)
| from vllm.config import VllmConfig, set_current_vllm_config, set_current_vllm_config
`set_current_vllm_config` is imported twice here.
| model_example_map = {
|     "qwen2_5_vl": run_qwen2_5_vl,
|     "Qwen/Qwen2.5-VL-3B-Instruct": run_qwen_vl,
I think specifying Qwen2.5-VL-3B-Instruct might not be general enough in case we want to test Qwen2.5-VL-7B. Could you make this more general?
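One way to generalize, sketched below: key the example map by model family rather than by a specific checkpoint, so any Qwen2.5-VL size (3B, 7B, ...) resolves to the same runner. The matching helper and its name are assumptions, not code from this PR.

```python
# Hypothetical sketch: resolve a full HF model name to a generic family key,
# so model_example_map can be keyed by family instead of exact checkpoint.
def resolve_model_family(model_name: str) -> str:
    """Map a full HF model name to a generic family key."""
    name = model_name.lower()
    if "qwen2.5-vl" in name or "qwen2_5_vl" in name:
        return "qwen2_5_vl"
    if "qwen3-vl" in name or "qwen3_vl" in name:
        return "qwen3_vl"
    raise ValueError(f"Unsupported model: {model_name}")

# Both checkpoint sizes resolve to the same runner key.
assert resolve_model_family("Qwen/Qwen2.5-VL-3B-Instruct") == "qwen2_5_vl"
assert resolve_model_family("Qwen/Qwen2.5-VL-7B-Instruct") == "qwen2_5_vl"
```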
| model_key = args.model
| req_data = model_example_map[model_key](questions, modality, args)
Could you add a check that model_key is in the map and raise an error if not?
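A minimal sketch of the suggested guard; the map entry and error message are placeholders, only the `model_example_map`/`model_key` names come from the diff:

```python
# Hypothetical stand-in for the real runner map in the example script.
model_example_map = {"qwen2_5_vl": lambda *args: "ran qwen2_5_vl"}

def run_example(model_key, *args):
    # Fail fast with a clear message instead of a bare KeyError.
    if model_key not in model_example_map:
        raise ValueError(
            f"Unknown model {model_key!r}; expected one of "
            f"{sorted(model_example_map)}")
    return model_example_map[model_key](*args)
```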
| else:
|     seg_ids = None
| from tpu_inference.layers.common.attention_interface import \
This is already imported at the top of this file.
| vllm_model._deepstack_tensors = {}
| if isinstance(deepstack_input_embeds, dict):
|     vllm_model._deepstack_tensors.update(deepstack_input_embeds)
One concern here: do we need to reset this _deepstack_tensors to {} if deepstack_input_embeds is None? Otherwise it may retain values from the previous forward pass.
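A sketch of that suggestion: clear the cache whenever `deepstack_input_embeds` is `None`, so stale tensors cannot leak into the next forward pass. The `_Model` stand-in and the helper name are assumptions for illustration; only the attribute and argument names come from the diff.

```python
# Hypothetical stand-in for the wrapped vLLM model object.
class _Model:
    _deepstack_tensors = None

vllm_model = _Model()

def set_deepstack_cache(deepstack_input_embeds):
    if deepstack_input_embeds is None:
        # Reset rather than retain values from the previous forward pass.
        vllm_model._deepstack_tensors = {}
        return
    if vllm_model._deepstack_tensors is None:
        vllm_model._deepstack_tensors = {}
    if isinstance(deepstack_input_embeds, dict):
        vllm_model._deepstack_tensors.update(deepstack_input_embeds)
```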
| )
| # Qwen3-VL uses a different method signature and takes in mm_features as an argument.
| import inspect
| takes_mm_features = "mm_features" in inspect.signature(get_mrope_input_positions_fn).parameters
Could you use supports_kw here instead of inspect.signature?
Also, I think this can be moved outside the for loop: https://github.com/muskansh-google/tpu-inference/blob/dd49ffcca0d74b577ea88512bd45742ee04a67ea/tpu_inference/runner/persistent_batch_manager.py#L132.
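Roughly what that would look like, sketched below. vLLM ships a `supports_kw` helper in `vllm.utils`; the local stand-in here mimics it (its body is an assumption) so the sketch is self-contained. The point is that the check depends only on the function object, so it can be computed once before the loop rather than per request.

```python
import inspect

# Local stand-in for vllm.utils.supports_kw (body is an assumption): a
# function "supports" a kwarg if it names it explicitly or takes **kwargs.
def supports_kw(fn, kw_name: str) -> bool:
    params = inspect.signature(fn).parameters
    return kw_name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

# Hypothetical examples of the two method signatures being distinguished.
def get_positions_new(seq, mm_features=None):
    return len(seq)

def get_positions_old(seq):
    return len(seq)

# Hoisted out of the per-request for loop: computed once per function.
takes_mm_features = supports_kw(get_positions_new, "mm_features")
assert takes_mm_features
assert not supports_kw(get_positions_old, "mm_features")
```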
Description
This PR enables Qwen3-VL inference on uLLM using the Torchax framework. By default, vLLM vision encoders are treated as part of the input-preparation phase rather than as part of the core compiled model. Because of this limitation, JIT-compiling the vision encoder is very difficult, so for the time being this PR runs the vision tower eagerly on TPUs.
Why this change is being made
Currently, uLLM does not support the Qwen3-VL model. The idea of using Torchax as the backend is to leverage the upstream vLLM model implementation. The PyTorch-based Qwen3-VL uses dynamic operations and in-place state mutations (specifically for Deepstack features) that conflict with JAX's requirement for statically traceable graphs. This PR provides monkey-patches and utility wrappers to bridge the gap between vLLM's PyTorch implementation and JAX execution.
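The wrapper-bridging idea can be sketched roughly as follows. Only the wrapper name `wrap_embed_input_ids_func` comes from the PR; the body, the normalization rule, and the example function are assumptions, shown with plain Python values instead of tensors.

```python
import inspect

def wrap_embed_input_ids_func(embed_fn):
    """Sketch: adapt a model's embed function to a fixed call convention."""
    accepted = set(inspect.signature(embed_fn).parameters)

    def wrapped(input_ids, **kwargs):
        # Qwen3-VL expects mm_embeds as a list of tensors, not one stacked
        # tensor, so normalize the shape here (assumption from the PR text).
        if "mm_embeds" in kwargs and not isinstance(kwargs["mm_embeds"], list):
            kwargs["mm_embeds"] = [kwargs["mm_embeds"]]
        # Drop dynamic kwargs the wrapped function does not accept.
        kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return embed_fn(input_ids, **kwargs)

    return wrapped
```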
Solved Problems & Relevance
1. `tpu_inference/models/vllm/vllm_model_wrapper.py`: `torch.sdpa` is registered through Torchax to utilize `sharded_flash_attention`, significantly improving vision encoder performance on TPUs.
2. `_set_deepstack_input_embeds` and `_get_deepstack_input_embeds` now use a stateless dictionary cache (`_deepstack_tensors`) instead of in-place model mutations.
3. Multimodal embeddings are carried in `inputs_embeds` to pass them safely through JIT boundaries.
4. Added `wrap_embed_multimodal_func` and `wrap_embed_input_ids_func` to handle dynamic kwargs and shape conversions (`mm_embeds` had to be passed as a list of tensors for Qwen3-VL).

Shortcomings and Future Improvements
The `inputs_embeds` side-channel is a workaround for JAX/JIT signature limitations.
Tests
The changes were verified using the `examples/multi_modal_inference.py` script on a TPU v6e VM.
Reproduction Commands
Standard single image inference: