We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3fd5260, commit 6546c11 (Copy full SHA for 6546c11)
1 file changed
vllm_rbln/v1/worker/optimum_model_runner.py
@@ -1295,8 +1295,15 @@ def sample_tokens(
1295
num_reqs = self.input_batch.num_reqs
1296
padded_logits = self.pooled_tensors[self.bucket_size]
1297
padded_logits[:num_reqs].copy_(logits)
1298
- else:
+ elif is_prompt:
1299
+ # Among the cases where self.input_batch.num_reqs > 1,
1300
+ # only the prefill stage of multimodal models produces logits
1301
+ # with varying strides.
1302
+ # To avoid frequent recompilations caused by these stride variations,
1303
+ # we flatten the logits into a 2D tensor with shape (1, -1).
1304
padded_logits = logits.reshape(1, -1)
1305
+ else:
1306
+ padded_logits = logits
1307
sampler_output = self._sample(padded_logits, spec_decode_metadata=None)
1308
self.input_batch.prev_sampled_token_ids = None
1309
0 commit comments