Skip to content

Commit e2863c6

Browse files
Test changes in compilation manager and add unit test
1 parent 1bc02b6 commit e2863c6

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

examples/multi_modal_inference.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,17 @@
88
on HuggingFace model repository.
99
1010
Example Command:
11-
python examples/multi_modal_inference.py \
11+
python3 examples/multi_modal_inference.py \
1212
--model Qwen/Qwen2.5-VL-3B-Instruct \
1313
--tensor-parallel-size 1 \
1414
--num-prompts 1
1515
1616
Example command to test multiple images:
17-
python examples/multi_modal_inference.py \
17+
python3 examples/multi_modal_inference.py \
1818
--model Qwen/Qwen3-VL-8B-Instruct \
1919
--test-multi-image \
20-
--max-model-len 8192
20+
--max-model-len 8192 \
21+
--gpu-memory-utilization 0.9
2122
"""
2223

2324
from contextlib import contextmanager
@@ -184,7 +185,7 @@ def parse_args():
184185
parser.add_argument(
185186
"--gpu-memory-utilization",
186187
type=float,
187-
default=0.5,
188+
default=0.85,
188189
help="GPU memory utilization",
189190
)
190191

tests/models/vllm/test_vllm_model_wrapper_multimodal.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,51 @@ def embed_multimodal(self, **kwargs):
146146

147147
actual_call_kwargs = call_kwargs["kwargs"]["call_kwargs"]
148148
assert "image_grid_thw" in actual_call_kwargs
149+
150+
151+
def test_wrap_embed_input_ids_func_for_qwen3vl():
152+
"""
153+
Test that wrap_embed_input_ids_func sets always_wrap_list to True
154+
for Qwen3VL architectures and wraps single tensors in a list.
155+
"""
156+
mock_wrapper = MagicMock()
157+
mock_wrapper.vllm_config.model_config.is_multimodal_model = True
158+
159+
# Set architecture to Qwen3VLForConditionalGeneration
160+
mock_wrapper.vllm_config.model_config.hf_config.architectures = ["Qwen3VLForConditionalGeneration"]
161+
162+
class MockInnerModel:
163+
def embed_input_ids(self, input_ids, mm_embeds, **kwargs):
164+
pass
165+
166+
mock_inner = MockInnerModel()
167+
mock_runner = MagicMock()
168+
mock_runner.vllm_model = mock_inner
169+
mock_wrapper.model = mock_runner
170+
171+
embed_input_ids_func = VllmModelWrapper.wrap_embed_input_ids_func(mock_wrapper)
172+
173+
with patch("torchax.default_env"), \
174+
patch("torch.func.functional_call") as mock_functional_call, \
175+
patch("tpu_inference.models.vllm.vllm_model_wrapper.torch_view") as mock_torch_view:
176+
177+
mock_torch_view.side_effect = lambda x: x
178+
179+
params_and_buffers = {}
180+
input_ids = jax.numpy.zeros((10, ), dtype=jax.numpy.int32)
181+
mm_embeds = jax.numpy.zeros((10, 512))
182+
183+
embed_input_ids_func(
184+
params_and_buffers,
185+
input_ids,
186+
mm_embeds,
187+
is_multimodal=jax.numpy.zeros((10, ), dtype=jax.numpy.bool_)
188+
)
189+
190+
mock_functional_call.assert_called_once()
191+
_, call_kwargs = mock_functional_call.call_args
192+
193+
actual_call_args = call_kwargs["kwargs"]["call_args"]
194+
# Verify mm_embeds (second argument) was wrapped in a list
195+
assert isinstance(actual_call_args[1], list)
196+
assert len(actual_call_args[1]) == 1

0 commit comments

Comments
 (0)