Skip to content

Commit e969871

Browse files
committed
fix decoder_start_token_id
1 parent 1e947e7 commit e969871

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

vllm_rbln/model_executor/models/optimum/whisper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
104104
cache_position=cache_position,
105105
input_block_ids=valid_block_ids,
106106
)
107-
decoder_input_ids = kwargs.pop("input_ids")
108107
decoder_cache_position = kwargs.pop("cache_position")
109108
decoder_block_tables = kwargs.pop("block_tables")
110109
# FIXME Is it ok generate torch.zero tensor for each forward?
@@ -123,7 +122,11 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
123122
decoder_cache_position[batch_idx] = 0
124123
decoder_attention_mask[batch_idx, 0] = 1
125124
self.dec_lengths[batch_idx] = 1
126-
125+
decoder_input_ids = torch.full(
126+
(self.batch_size, 1),
127+
self.model.config.decoder_start_token_id,
128+
dtype=torch.long,
129+
)
127130
decoder_output = self.model.decoder(
128131
decoder_input_ids=decoder_input_ids.contiguous(),
129132
decoder_attention_mask=decoder_attention_mask,
@@ -132,6 +135,7 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
132135
)
133136

134137
else:
138+
decoder_input_ids = kwargs.pop("input_ids")
135139
# Generate cache_position using dec_lengths
136140
for batch_idx in valid_block_ids:
137141
decoder_cache_position[batch_idx] = self.dec_lengths[batch_idx]

0 commit comments

Comments (0)