@@ -98,17 +98,18 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
         cache_position = torch.zeros(request_nums, 1, dtype=torch.int32)

         kwargs = self.preprocess_for_decoder(
-            is_prompt=is_prompt,
+            is_prompt=False,
             block_tables=block_tables,
             input_ids=input_ids,
             cache_position=cache_position,
             input_block_ids=valid_block_ids,
         )
-        input_ids = kwargs.pop("input_ids")
-        cache_position = kwargs.pop("cache_position")
+        decoder_input_ids = kwargs.pop("input_ids")
+        decoder_cache_position = kwargs.pop("cache_position")
         decoder_block_tables = kwargs.pop("block_tables")
         # FIXME Is it ok to generate a torch.zeros tensor for each forward,
         # or just generate a pooled tensor in the model instance?
+        # FIXME bucketing?
         decoder_attention_mask = torch.zeros(
             self.batch_size, self.dec_max_seq_len, dtype=self.dtype
         )
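Note on this hunk: popping into `decoder_input_ids` / `decoder_cache_position` stops the decoder tensors from shadowing the `input_ids` and `cache_position` that were just handed to `preprocess_for_decoder`. On the FIXME about allocating a fresh `torch.zeros` mask on every forward: one answer to the question the comment itself raises is to preallocate the mask once on the model instance and clear it in place each step. A minimal sketch of that pooled-buffer idea, not code from this PR; the `PooledDecoderMask` name and the sizes are hypothetical:

```python
import torch

class PooledDecoderMask:
    """Hypothetical pooled-buffer alternative to a per-forward torch.zeros mask."""

    def __init__(self, batch_size: int, dec_max_seq_len: int, dtype=torch.float32):
        # Allocate once at model construction; reuse across forward() calls.
        self.mask = torch.zeros(batch_size, dec_max_seq_len, dtype=dtype)

    def reset(self) -> torch.Tensor:
        # In-place zeroing avoids a fresh allocation on every decode step.
        self.mask.zero_()
        return self.mask

# Usage: reset once per step, then fill per batch index as the diff does.
pool = PooledDecoderMask(batch_size=4, dec_max_seq_len=16)
decoder_attention_mask = pool.reset()
decoder_attention_mask[0, 0] = 1
```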
@@ -119,28 +120,30 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
             block_tables=block_tables.squeeze(0).to(torch.int16),
         )
         for batch_idx in valid_block_ids:
-            cache_position[batch_idx] = 0
+            decoder_cache_position[batch_idx] = 0
             decoder_attention_mask[batch_idx, 0] = 1
             self.dec_lengths[batch_idx] = 1

         decoder_output = self.model.decoder(
-            decoder_input_ids=input_ids.contiguous(),
+            decoder_input_ids=decoder_input_ids.contiguous(),
             decoder_attention_mask=decoder_attention_mask,
-            cache_position=cache_position,
+            cache_position=decoder_cache_position,
             block_tables=decoder_block_tables,
         )

     else:
         # Generate cache_position using dec_lengths
         for batch_idx in valid_block_ids:
-            cache_position[batch_idx] = self.dec_lengths[batch_idx]
-            decoder_attention_mask[batch_idx, : cache_position[batch_idx] + 1] = 1
+            decoder_cache_position[batch_idx] = self.dec_lengths[batch_idx]
+            decoder_attention_mask[
+                batch_idx, : decoder_cache_position[batch_idx] + 1
+            ] = 1
             self.dec_lengths[batch_idx] += 1

         decoder_output = self.model.decoder(
-            decoder_input_ids=input_ids.contiguous(),
+            decoder_input_ids=decoder_input_ids.contiguous(),
             decoder_attention_mask=decoder_attention_mask,
-            cache_position=cache_position,
+            cache_position=decoder_cache_position,
             block_tables=decoder_block_tables,
         )
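Taken together, the two branches keep per-sequence decode state: on the first step each valid sequence writes at cache position 0 with a one-position mask and `dec_lengths` is set to 1; on later steps the write position comes from `dec_lengths[batch_idx]`, the mask covers positions 0 through `dec_lengths[batch_idx]` inclusive, and the counter is advanced. A self-contained sketch of that bookkeeping for reviewers; `prepare_step` and the sizes are hypothetical, not from this PR:

```python
import torch

batch_size, dec_max_seq_len = 4, 16
dec_lengths = torch.zeros(batch_size, dtype=torch.int32)

def prepare_step(valid_block_ids, first_step: bool):
    # Mirrors the diff: fresh cache_position and attention mask per forward.
    cache_position = torch.zeros(batch_size, 1, dtype=torch.int32)
    attention_mask = torch.zeros(batch_size, dec_max_seq_len, dtype=torch.float32)
    for batch_idx in valid_block_ids:
        if first_step:
            # First decode step: write at position 0 and attend to it alone.
            cache_position[batch_idx] = 0
            attention_mask[batch_idx, 0] = 1
            dec_lengths[batch_idx] = 1
        else:
            # Later steps: write at dec_lengths, attend to 0..dec_lengths.
            cache_position[batch_idx] = dec_lengths[batch_idx]
            attention_mask[batch_idx, : dec_lengths[batch_idx] + 1] = 1
            dec_lengths[batch_idx] += 1
    return cache_position, attention_mask

prepare_step([0, 2], first_step=True)              # step 1 for sequences 0 and 2
cp, mask = prepare_step([0, 2], first_step=False)  # step 2
assert cp[0].item() == 1 and mask[0, :2].sum() == 2
```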