Commit 6546c11

fix: reshape logits only in prefill phase (#409)
1 parent 3fd5260 commit 6546c11

1 file changed

Lines changed: 8 additions & 1 deletion

vllm_rbln/v1/worker/optimum_model_runner.py

@@ -1295,8 +1295,15 @@ def sample_tokens(
             num_reqs = self.input_batch.num_reqs
             padded_logits = self.pooled_tensors[self.bucket_size]
             padded_logits[:num_reqs].copy_(logits)
-        else:
+        elif is_prompt:
+            # Among self.input_batch.num_reqs > 1 cases,
+            # only the prefill stage of multimodal models produces logits
+            # with varying strides during the prefill stage.
+            # To avoid frequent recompilations caused by these stride variations,
+            # we flatten the logits into a 2D tensor with shape (1, -1).
             padded_logits = logits.reshape(1, -1)
+        else:
+            padded_logits = logits
         sampler_output = self._sample(padded_logits, spec_decode_metadata=None)
         self.input_batch.prev_sampled_token_ids = None
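
The effect of the new three-way branch on the shape passed to `_sample` can be sketched without any tensor library. This is a minimal illustration, not the runner's code. The condition guarding the padded-buffer branch is not visible in the hunk, so `use_padded_buffer` is a hypothetical stand-in for it, and the padded buffer is assumed to have shape `(bucket_size, vocab_size)`. The helper names are also made up for this sketch.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def flatten_to_one_row(shape: Shape) -> Shape:
    """Shape produced by tensor.reshape(1, -1) for an input of `shape`."""
    return (1, prod(shape))


def shape_before_commit(logits_shape: Shape, use_padded_buffer: bool,
                        bucket_size: int) -> Shape:
    # `use_padded_buffer` is hypothetical: the real condition sits above the hunk.
    if use_padded_buffer:
        # Assumed shape of self.pooled_tensors[self.bucket_size].
        return (bucket_size, logits_shape[-1])
    # Old `else`: every remaining case was flattened, prefill or not.
    return flatten_to_one_row(logits_shape)


def shape_after_commit(logits_shape: Shape, use_padded_buffer: bool,
                       is_prompt: bool, bucket_size: int) -> Shape:
    if use_padded_buffer:
        return (bucket_size, logits_shape[-1])
    if is_prompt:
        # New `elif is_prompt`: flatten only during prefill.
        return flatten_to_one_row(logits_shape)
    # New `else`: other logits pass through unchanged.
    return logits_shape


if __name__ == "__main__":
    decode_batch = (4, 8)  # 4 requests, toy vocabulary of 8
    print(shape_before_commit(decode_batch, False, bucket_size=8))
    print(shape_after_commit(decode_batch, False, is_prompt=False, bucket_size=8))
```

Under these assumptions, a non-prefill batch of shape `(4, 8)` that reached the old `else` would have been collapsed into a single row of 32 values, while the new code leaves it as four rows. Whether such a batch could reach that branch in practice depends on the guard that the hunk does not show.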
