Skip to content

Commit 42ad5b2

Browse files
committed
fix: run both encoder and decoder in prefill step
1 parent ec6da31 commit 42ad5b2

1 file changed

Lines changed: 21 additions & 13 deletions

File tree

vllm_rbln/model_executor/models/optimum/whisper.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ class RBLNOptimumWhisperForConditionalGeneration(
3737
SupportsTranscription,
3838
SupportsMultiModal,
3939
):
40-
INVALID_TOKEN = 100
4140
# Whisper only supports audio-conditioned generation.
4241
supports_transcription_only = True
4342
supports_segment_timestamp = True
@@ -113,19 +112,28 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
113112
_ = self.model.encoder(
114113
input_features=input_features, block_tables=block_tables
115114
)
116-
lm_logits = torch.zeros(
117-
1, 1, self.model.config.vocab_size + self.INVALID_TOKEN
115+
116+
decoder_input_ids = torch.full(
117+
(request_nums, 1),
118+
self.model.config.decoder_start_token_id,
119+
dtype=torch.long,
120+
)
121+
decoder_attention_mask = torch.zeros(
122+
self.batch_size, self.dec_max_seq_len, dtype=self.dtype
118123
)
119-
# Set the probability of INVALID_TOKEN (the last token in
120-
# the logits tensor) to 1.0.
121-
lm_logits[0][0][-1] = 1
122-
self.dec_lengths[valid_block_ids[0].item()] = 0
124+
for batch_idx in valid_block_ids:
125+
cache_position[batch_idx] = 0
126+
decoder_attention_mask[batch_idx, 0] = 1
127+
self.dec_lengths[batch_idx] = 1
123128

124-
else:
125-
input_ids[
126-
input_ids == (self.model.config.vocab_size + self.INVALID_TOKEN - 1)
127-
] = self.model.config.decoder_start_token_id
129+
decoder_output = self.model.decoder(
130+
decoder_input_ids=decoder_input_ids.contiguous(),
131+
decoder_attention_mask=decoder_attention_mask,
132+
cache_position=cache_position,
133+
block_tables=block_tables.unsqueeze(-1),
134+
)
128135

136+
else:
129137
# FIXME Is it ok generate torch.zero tensor for each forward?
130138
# OR just generate pooled tensor in the model instance?
131139
decoder_attention_mask = torch.zeros(
@@ -144,8 +152,8 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
144152
block_tables=block_tables,
145153
)
146154

147-
lm_logits = decoder_output.logits
148-
lm_logits = lm_logits[valid_block_ids]
155+
lm_logits = decoder_output.logits
156+
lm_logits = lm_logits[valid_block_ids]
149157
return lm_logits
150158

151159
def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:

0 commit comments

Comments (0)