Commit b41fe2c

Fix mm embedding to be 3d tensor
1 parent e2863c6 commit b41fe2c

1 file changed: +1 -1 lines changed

tpu_inference/runner/compilation_manager.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def _precompile_input_embeddings_merger(self) -> None:
         )
         sharding = NamedSharding(self.runner.mesh, PartitionSpec())
         dummy_multimodal_embeddings = self._create_dummy_tensor(
-            (num_tokens, hidden_size),
+            (1, num_tokens, hidden_size),
             self.runner.vllm_config.model_config.dtype,
             sharding=sharding)
         dummy_input_ids = self._create_dummy_tensor((num_tokens, ),
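
The commit message does not explain why the leading dimension matters, so the following is only a sketch of one plausible reading. It assumes that the precompile step exists to warm a compile cache keyed on input shapes (as JIT compilers such as `jax.jit` behave), and that the embeddings merger receives a 3D `(1, num_tokens, hidden_size)` tensor at runtime. `ShapeKeyedCache`, the sizes, and the executable labels are hypothetical and are not part of `tpu_inference`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


class ShapeKeyedCache:
    """Toy stand-in for a JIT compile cache keyed on argument shapes."""

    def __init__(self) -> None:
        self.compiled: Dict[Tuple[Shape, ...], str] = {}
        self.compile_count = 0

    def call(self, *shapes: Shape) -> str:
        # A new combination of shapes triggers a (simulated) compilation.
        key = tuple(shapes)
        if key not in self.compiled:
            self.compile_count += 1
            self.compiled[key] = f"executable#{self.compile_count}"
        return self.compiled[key]


def recompiles_at_runtime(warmup_embedding_shape: Shape,
                          runtime_embedding_shape: Shape,
                          input_ids_shape: Shape) -> bool:
    """Return True if the runtime call misses the warmed-up cache."""
    cache = ShapeKeyedCache()
    cache.call(warmup_embedding_shape, input_ids_shape)   # precompile step
    before = cache.compile_count
    cache.call(runtime_embedding_shape, input_ids_shape)  # real request
    return cache.compile_count > before


# Hypothetical sizes, chosen only for illustration.
num_tokens, hidden_size = 16, 8
runtime_shape = (1, num_tokens, hidden_size)  # assumed 3D runtime input

# Before the commit: 2D dummy tensor, so the 3D runtime call misses the cache.
old_recompiled = recompiles_at_runtime(
    (num_tokens, hidden_size), runtime_shape, (num_tokens,))

# After the commit: 3D dummy tensor matches the assumed runtime shape.
new_recompiled = recompiles_at_runtime(
    (1, num_tokens, hidden_size), runtime_shape, (num_tokens,))
```

Under these assumptions, only the 3D dummy tensor lets the first real request reuse the precompiled executable; whether that is the actual motivation for the change is not stated in the commit.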
