@@ -146,3 +146,51 @@ def embed_multimodal(self, **kwargs):
146146
147147 actual_call_kwargs = call_kwargs ["kwargs" ]["call_kwargs" ]
148148 assert "image_grid_thw" in actual_call_kwargs
149+
150+
def test_wrap_embed_input_ids_func_for_qwen3vl():
    """
    Check list-wrapping of multimodal embeddings for Qwen3VL.

    The wrapper is configured as a multimodal model whose HF architecture
    is ``Qwen3VLForConditionalGeneration``. After invoking the function
    returned by ``wrap_embed_input_ids_func`` with a single ``mm_embeds``
    array, the positional arguments forwarded through
    ``torch.func.functional_call`` must carry that array wrapped in a
    one-element list.
    """

    class MockInnerModel:

        def embed_input_ids(self, input_ids, mm_embeds, **kwargs):
            """Stub; never executed because functional_call is patched."""

    # Build the fake wrapper: multimodal model, Qwen3VL architecture.
    mock_wrapper = MagicMock()
    model_config = mock_wrapper.vllm_config.model_config
    model_config.is_multimodal_model = True
    model_config.hf_config.architectures = [
        "Qwen3VLForConditionalGeneration"
    ]
    mock_wrapper.model = MagicMock(vllm_model=MockInnerModel())

    embed_input_ids_func = VllmModelWrapper.wrap_embed_input_ids_func(
        mock_wrapper)

    # torch_view is replaced by an identity function so the inputs pass
    # through unchanged and can be inspected on the functional_call mock.
    with patch("torchax.default_env"), \
         patch("torch.func.functional_call") as mock_functional_call, \
         patch("tpu_inference.models.vllm.vllm_model_wrapper.torch_view",
               side_effect=lambda x: x):

        jnp = jax.numpy
        params_and_buffers = {}
        input_ids = jnp.zeros((10, ), dtype=jnp.int32)
        mm_embeds = jnp.zeros((10, 512))
        is_multimodal = jnp.zeros((10, ), dtype=jnp.bool_)

        embed_input_ids_func(
            params_and_buffers,
            input_ids,
            mm_embeds,
            is_multimodal=is_multimodal,
        )

        mock_functional_call.assert_called_once()
        forwarded = mock_functional_call.call_args.kwargs["kwargs"]

        # The second forwarded positional argument is mm_embeds; it must
        # have been wrapped in a single-element list.
        wrapped_mm_embeds = forwarded["call_args"][1]
        assert isinstance(wrapped_mm_embeds, list)
        assert len(wrapped_mm_embeds) == 1
0 commit comments