@@ -78,11 +78,8 @@ def __init__(
     def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
         input_ids = model_input.input_tokens
         block_tables = model_input.block_tables
-
         request_nums = input_ids.shape[0]
-
         is_prompt = model_input.is_prompt
-
         valid_block_ids = block_tables.flatten().to(torch.int32)
 
         if is_prompt:
@@ -93,10 +90,18 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
             input_features = audio_input["input_features"]
             if input_features is None:
                 raise ValueError("Whisper requires `input_features` as an input.")
-            # FIXME I think encoder should be called here
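+            # Prefill: run the encoder on the audio features against the
+            # allocated block tables; its return value is not used directly.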
+            _ = self.model.encoder(
+                input_features=input_features,
+                block_tables=block_tables.squeeze(0).to(torch.int16),
+            )
 
         cache_position = torch.zeros(request_nums, 1, dtype=torch.int32)
 
+        # In the Whisper model, a decoder input is always required in the
+        # prefill step, so is_prompt=False is passed for both the prefill
+        # and decode steps.
         kwargs = self.preprocess_for_decoder(
             is_prompt=False,
             block_tables=block_tables,
@@ -106,27 +111,23 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
         )
         decoder_cache_position = kwargs.pop("cache_position")
         decoder_block_tables = kwargs.pop("block_tables")
-        # FIXME Is it ok generate torch.zero tensor for each forward?
-        # OR just generate pooled tensor in the model instance?
-        # FIXME bucketing?
+
+        # Whisper does not support bucketing, so the decoder attention mask
+        # is always allocated at its full (batch_size, dec_max_seq_len) shape.
         decoder_attention_mask = torch.zeros(
             self.batch_size, self.dec_max_seq_len, dtype=self.dtype
         )
         if is_prompt:
-            print("block_tables", block_tables)
-            _ = self.model.encoder(
-                input_features=input_features,
-                block_tables=block_tables.squeeze(0).to(torch.int16),
-            )
-            for batch_idx in valid_block_ids:
-                decoder_cache_position[batch_idx] = 0
-                decoder_attention_mask[batch_idx, 0] = 1
-                self.dec_lengths[batch_idx] = 1
             decoder_input_ids = torch.full(
                 (self.batch_size, 1),
                 self.model.config.decoder_start_token_id,
                 dtype=torch.long,
             )
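+            # Reset each valid slot's decoder state for the first decode step.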
+            for batch_idx in valid_block_ids:
+                decoder_cache_position[batch_idx] = 0
+                decoder_attention_mask[batch_idx, 0] = 1
+                self.dec_lengths[batch_idx] = 1
         decoder_output = self.model.decoder(
             decoder_input_ids=decoder_input_ids.contiguous(),
             decoder_attention_mask=decoder_attention_mask,