@@ -280,26 +280,30 @@ def decode_forward(
         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
 
-        if self.batch_size != cache_position.shape[0]:
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
             raise RuntimeError(
-                f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
             )
 
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
         if is_external_block_tables:
             if attention_mask is None:
                 raise ValueError("attention_mask should be provided with external block tables.")
             if local_block_tables is None:
                 raise ValueError("local_block_tables should be provided with external block tables.")
-
-        if self.rbln_config.use_local_attention:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
+        else:
+            if self.rbln_config.use_local_attention:
+                local_block_tables = (
+                    local_block_tables
+                    if local_block_tables is not None
+                    else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
+                )
 
         if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(self.batch_size):
+            for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(