@@ -37,7 +37,6 @@ class RBLNOptimumWhisperForConditionalGeneration(
3737 SupportsTranscription ,
3838 SupportsMultiModal ,
3939):
40- INVALID_TOKEN = 100
4140 # Whisper only supports audio-conditioned generation.
4241 supports_transcription_only = True
4342 supports_segment_timestamp = True
@@ -113,19 +112,28 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
113112 _ = self .model .encoder (
114113 input_features = input_features , block_tables = block_tables
115114 )
116- lm_logits = torch .zeros (
117- 1 , 1 , self .model .config .vocab_size + self .INVALID_TOKEN
115+
116+ decoder_input_ids = torch .full (
117+ (request_nums , 1 ),
118+ self .model .config .decoder_start_token_id ,
119+ dtype = torch .long ,
120+ )
121+ decoder_attention_mask = torch .zeros (
122+ self .batch_size , self .dec_max_seq_len , dtype = self .dtype
118123 )
119- # Set the probability of INVALID_TOKEN (the last token in
120- # the logits tensor) to 1.0.
121- lm_logits [ 0 ][ 0 ][ - 1 ] = 1
122- self .dec_lengths [valid_block_ids [ 0 ]. item ()] = 0
124+ for batch_idx in valid_block_ids :
125+ cache_position [ batch_idx ] = 0
126+ decoder_attention_mask [ batch_idx , 0 ] = 1
127+ self .dec_lengths [batch_idx ] = 1
123128
124- else :
125- input_ids [
126- input_ids == (self .model .config .vocab_size + self .INVALID_TOKEN - 1 )
127- ] = self .model .config .decoder_start_token_id
129+ decoder_output = self .model .decoder (
130+ decoder_input_ids = decoder_input_ids .contiguous (),
131+ decoder_attention_mask = decoder_attention_mask ,
132+ cache_position = cache_position ,
133+ block_tables = block_tables .unsqueeze (- 1 ),
134+ )
128135
136+ else :
129137 # FIXME: Is it OK to allocate a new torch.zeros tensor on every forward call,
130138 # or should a pooled tensor be kept on the model instance and reused?
131139 decoder_attention_mask = torch .zeros (
@@ -144,8 +152,8 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
144152 block_tables = block_tables ,
145153 )
146154
147- lm_logits = decoder_output .logits
148- lm_logits = lm_logits [valid_block_ids ]
155+ lm_logits = decoder_output .logits
156+ lm_logits = lm_logits [valid_block_ids ]
149157 return lm_logits
150158
151159 def _parse_and_validate_audio_input (self , ** kwargs : object ) -> WhisperAudioInputs :
0 commit comments