Skip to content

Commit a5445a2

Browse files
committed
refactor whisper model code
1 parent e969871 commit a5445a2

1 file changed

Lines changed: 13 additions & 16 deletions

File tree

vllm_rbln/model_executor/models/optimum/whisper.py

Lines changed: 13 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -78,11 +78,8 @@ def __init__(
     def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
         input_ids = model_input.input_tokens
         block_tables = model_input.block_tables
-
         request_nums = input_ids.shape[0]
-
         is_prompt = model_input.is_prompt
-
         valid_block_ids = block_tables.flatten().to(torch.int32)

         if is_prompt:
@@ -93,10 +90,16 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
             input_features = audio_input["input_features"]
             if input_features is None:
                 raise ValueError("Whisper requires `input_features` as an input.")
-            # FIXME I think encoder should be called here
+            _ = self.model.encoder(
+                input_features=input_features,
+                block_tables=block_tables.squeeze(0).to(torch.int16),
+            )

         cache_position = torch.zeros(request_nums, 1, dtype=torch.int32)

+        # In whisper model,
+        # decoder input is always required in prefill step,
+        # so is_prompt=False is set for both prefill and decode step.
         kwargs = self.preprocess_for_decoder(
             is_prompt=False,
             block_tables=block_tables,
@@ -106,27 +109,21 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
         )
         decoder_cache_position = kwargs.pop("cache_position")
         decoder_block_tables = kwargs.pop("block_tables")
-        # FIXME Is it ok generate torch.zero tensor for each forward?
-        # OR just generate pooled tensor in the model instance?
-        # FIXME bucketing?
+
+        # Whisper model does not support bucketing.
         decoder_attention_mask = torch.zeros(
             self.batch_size, self.dec_max_seq_len, dtype=self.dtype
         )
         if is_prompt:
-            print("block_tables", block_tables)
-            _ = self.model.encoder(
-                input_features=input_features,
-                block_tables=block_tables.squeeze(0).to(torch.int16),
-            )
-            for batch_idx in valid_block_ids:
-                decoder_cache_position[batch_idx] = 0
-                decoder_attention_mask[batch_idx, 0] = 1
-                self.dec_lengths[batch_idx] = 1
             decoder_input_ids = torch.full(
                 (self.batch_size, 1),
                 self.model.config.decoder_start_token_id,
                 dtype=torch.long,
             )
+            for batch_idx in valid_block_ids:
+                decoder_cache_position[batch_idx] = 0
+                decoder_attention_mask[batch_idx, 0] = 1
+                self.dec_lengths[batch_idx] = 1
         decoder_output = self.model.decoder(
             decoder_input_ids=decoder_input_ids.contiguous(),
             decoder_attention_mask=decoder_attention_mask,

0 commit comments

Comments
 (0)