Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 40 additions & 36 deletions vllm_rbln/model_executor/models/optimum/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class RBLNOptimumWhisperForConditionalGeneration(
SupportsTranscription,
SupportsMultiModal,
):
INVALID_TOKEN = 100
# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
supports_segment_timestamp = True
Expand Down Expand Up @@ -79,11 +78,8 @@ def __init__(
def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
input_ids = model_input.input_tokens
block_tables = model_input.block_tables

request_nums = input_ids.shape[0]

is_prompt = model_input.is_prompt

valid_block_ids = block_tables.flatten().to(torch.int32)

if is_prompt:
Expand All @@ -94,58 +90,66 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
input_features = audio_input["input_features"]
if input_features is None:
raise ValueError("Whisper requires `input_features` as an input.")
# FIXME: I think the encoder should be called here.
_ = self.model.encoder(
input_features=input_features,
block_tables=block_tables.squeeze(0).to(torch.int16),
)

cache_position = torch.zeros(request_nums, 1, dtype=torch.int32)

# In the Whisper model, the decoder input is always required in the
# prefill step, so is_prompt=False is set for both the prefill and
# decode steps.
kwargs = self.preprocess_for_decoder(
is_prompt,
block_tables,
input_ids,
cache_position,
is_prompt=False,
block_tables=block_tables,
input_ids=input_ids,
cache_position=cache_position,
input_block_ids=valid_block_ids,
)
input_ids = kwargs.pop("input_ids")
cache_position = kwargs.pop("cache_position")
block_tables = kwargs.pop("block_tables")
decoder_cache_position = kwargs.pop("cache_position")
decoder_block_tables = kwargs.pop("block_tables")

# Whisper model does not support bucketing.
decoder_attention_mask = torch.zeros(
self.batch_size, self.dec_max_seq_len, dtype=self.dtype
)
if is_prompt:
_ = self.model.encoder(
input_features=input_features, block_tables=block_tables
decoder_input_ids = torch.full(
(self.batch_size, 1),
self.model.config.decoder_start_token_id,
dtype=torch.long,
)
lm_logits = torch.zeros(
1, 1, self.model.config.vocab_size + self.INVALID_TOKEN
for batch_idx in valid_block_ids:
decoder_cache_position[batch_idx] = 0
decoder_attention_mask[batch_idx, 0] = 1
self.dec_lengths[batch_idx] = 1
decoder_output = self.model.decoder(
decoder_input_ids=decoder_input_ids.contiguous(),
decoder_attention_mask=decoder_attention_mask,
cache_position=decoder_cache_position,
block_tables=decoder_block_tables,
)
# Set the probability of INVALID_TOKEN (the last token in
# the logits tensor) to 1.0.
lm_logits[0][0][-1] = 1
self.dec_lengths[valid_block_ids[0].item()] = 0

else:
input_ids[
input_ids == (self.model.config.vocab_size + self.INVALID_TOKEN - 1)
] = self.model.config.decoder_start_token_id

# FIXME: Is it OK to create a new torch.zeros tensor on every forward call,
# or should a pooled tensor be kept on the model instance instead?
decoder_attention_mask = torch.zeros(
self.batch_size, self.dec_max_seq_len, dtype=self.dtype
)
decoder_input_ids = kwargs.pop("input_ids")
# Generate cache_position using dec_lengths
for batch_idx in valid_block_ids:
cache_position[batch_idx] = self.dec_lengths[batch_idx]
decoder_attention_mask[batch_idx, : cache_position[batch_idx] + 1] = 1
decoder_cache_position[batch_idx] = self.dec_lengths[batch_idx]
decoder_attention_mask[
batch_idx, : decoder_cache_position[batch_idx] + 1
] = 1
self.dec_lengths[batch_idx] += 1

decoder_output = self.model.decoder(
decoder_input_ids=input_ids.contiguous(),
decoder_input_ids=decoder_input_ids.contiguous(),
decoder_attention_mask=decoder_attention_mask,
cache_position=cache_position,
block_tables=block_tables,
cache_position=decoder_cache_position,
block_tables=decoder_block_tables,
)

lm_logits = decoder_output.logits
lm_logits = lm_logits[valid_block_ids]
lm_logits = decoder_output.logits
lm_logits = lm_logits[valid_block_ids]
return lm_logits

def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
Expand Down
Loading