Skip to content

Commit f243fd1

Browse files
committed
fix logic
1 parent 83c655a commit f243fd1

File tree

1 file changed

+14
-10
lines changed

src/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,26 +280,30 @@ def decode_forward(
280280
if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
281281
raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
282282

283-
if self.batch_size != cache_position.shape[0]:
283+
batch_size = inputs.shape[0]
284+
if batch_size != self.batch_size:
284285
raise RuntimeError(
285-
f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
286+
f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
286287
)
287288

289+
if batch_size != cache_position.shape[0]:
290+
raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
291+
288292
if is_external_block_tables:
289293
if attention_mask is None:
290294
raise ValueError("attention_mask should be provided with external block tables.")
291295
if local_block_tables is None:
292296
raise ValueError("local_block_tables should be provided with external block tables.")
293-
294-
if self.rbln_config.use_local_attention:
295-
local_block_tables = (
296-
local_block_tables
297-
if local_block_tables is not None
298-
else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
299-
)
297+
else:
298+
if self.rbln_config.use_local_attention:
299+
local_block_tables = (
300+
local_block_tables
301+
if local_block_tables is not None
302+
else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
303+
)
300304

301305
if self.rbln_config.use_attention_mask and attention_mask is None:
302-
for b_idx in range(self.batch_size):
306+
for b_idx in range(batch_size):
303307
decoding_step = cache_position[b_idx].item()
304308
if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
305309
raise ValueError(

0 commit comments

Comments (0)