Commit 1e947e7

fix: shape
1 parent: b35b22f

1 file changed: 13 additions & 10 deletions

vllm_rbln/model_executor/models/optimum/whisper.py

@@ -98,17 +98,18 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
         cache_position = torch.zeros(request_nums, 1, dtype=torch.int32)
 
         kwargs = self.preprocess_for_decoder(
-            is_prompt=is_prompt,
+            is_prompt=False,
             block_tables=block_tables,
             input_ids=input_ids,
             cache_position=cache_position,
             input_block_ids=valid_block_ids,
         )
-        input_ids = kwargs.pop("input_ids")
-        cache_position = kwargs.pop("cache_position")
+        decoder_input_ids = kwargs.pop("input_ids")
+        decoder_cache_position = kwargs.pop("cache_position")
         decoder_block_tables = kwargs.pop("block_tables")
         # FIXME Is it ok generate torch.zero tensor for each forward?
         # OR just generate pooled tensor in the model instance?
+        # FIXME bucketing?
         decoder_attention_mask = torch.zeros(
             self.batch_size, self.dec_max_seq_len, dtype=self.dtype
         )
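
The hunk above pins is_prompt=False at this call site (the surrounding code is the decoder path) and renames the tensors popped from preprocess_for_decoder so they no longer shadow the request-shaped input_ids and cache_position built at the top of forward. A minimal sketch of the shape hazard the rename avoids; the shapes and the padding behaviour of preprocess_for_decoder below are assumptions for illustration, not taken from the repository:

import torch

# Assume 2 live requests padded out to a static decoder batch of 4.
request_nums, batch_size = 2, 4

input_ids = torch.ones(request_nums, 1, dtype=torch.int64)        # per-request
cache_position = torch.zeros(request_nums, 1, dtype=torch.int32)  # per-request

# Hypothetical stand-ins for kwargs.pop("input_ids") / kwargs.pop("cache_position"):
decoder_input_ids = torch.zeros(batch_size, 1, dtype=torch.int64)
decoder_input_ids[:request_nums] = input_ids
decoder_cache_position = torch.zeros(batch_size, 1, dtype=torch.int32)

# Before this commit the padded tensors were assigned back to the original
# names, so later code expecting request-shaped tensors silently received
# batch-shaped ones -- hence "fix: shape".
assert input_ids.shape == (request_nums, 1)
assert decoder_input_ids.shape == (batch_size, 1)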
@@ -119,28 +120,30 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
                 block_tables=block_tables.squeeze(0).to(torch.int16),
             )
             for batch_idx in valid_block_ids:
-                cache_position[batch_idx] = 0
+                decoder_cache_position[batch_idx] = 0
                 decoder_attention_mask[batch_idx, 0] = 1
                 self.dec_lengths[batch_idx] = 1
 
             decoder_output = self.model.decoder(
-                decoder_input_ids=input_ids.contiguous(),
+                decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=decoder_attention_mask,
-                cache_position=cache_position,
+                cache_position=decoder_cache_position,
                 block_tables=decoder_block_tables,
             )
 
         else:
             # Generate cache_position using dec_lengths
             for batch_idx in valid_block_ids:
-                cache_position[batch_idx] = self.dec_lengths[batch_idx]
-                decoder_attention_mask[batch_idx, : cache_position[batch_idx] + 1] = 1
+                decoder_cache_position[batch_idx] = self.dec_lengths[batch_idx]
+                decoder_attention_mask[
+                    batch_idx, : decoder_cache_position[batch_idx] + 1
+                ] = 1
                 self.dec_lengths[batch_idx] += 1
 
             decoder_output = self.model.decoder(
-                decoder_input_ids=input_ids.contiguous(),
+                decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=decoder_attention_mask,
-                cache_position=cache_position,
+                cache_position=decoder_cache_position,
                 block_tables=decoder_block_tables,
             )
 
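For context, the two branches in this hunk maintain per-slot decode state: right after the prompt/encoder pass, each valid slot's cache position is reset to 0, one attention-mask position is opened, and dec_lengths is set to 1; in the else branch the position is read from dec_lengths, the mask covers positions 0..position inclusive, and dec_lengths advances. A hedged, self-contained sketch of that bookkeeping (batch size, sequence length, and valid_block_ids below are made-up values):

import torch

batch_size, dec_max_seq_len = 4, 8
valid_block_ids = [0, 2]  # assumed live batch slots
dec_lengths = torch.zeros(batch_size, dtype=torch.int32)
decoder_cache_position = torch.zeros(batch_size, 1, dtype=torch.int32)

# First decode step after the prompt/encoder pass (the if branch):
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)
for batch_idx in valid_block_ids:
    decoder_cache_position[batch_idx] = 0
    decoder_attention_mask[batch_idx, 0] = 1
    dec_lengths[batch_idx] = 1

# A subsequent decode step (the else branch): the position comes from
# dec_lengths and the mask covers positions 0..position inclusive.
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)
for batch_idx in valid_block_ids:
    decoder_cache_position[batch_idx] = dec_lengths[batch_idx]
    decoder_attention_mask[batch_idx, : decoder_cache_position[batch_idx] + 1] = 1
    dec_lengths[batch_idx] += 1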