Skip to content

Commit 440a40e

Browse files
fix: update batch attention logic to handle padding and size conditions (#436)
1 parent 7f142ce commit 440a40e

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

vllm_rbln/v1/attention/backends/flash_attention.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1203,7 +1203,7 @@ def build(
12031203
query_start_loc=query_start_loc,
12041204
max_seq_len=query_max_seq_len,
12051205
seq_lens=seq_lens_tensor.to(self.device)
1206-
if not self.is_batch_attention_opt or is_prefills[0]
1206+
if not self.is_batch_attention_opt or is_prefills[0] or batch_pad <= 1
12071207
else seq_idx.to(self.device),
12081208
block_tables=block_tables_tensor.to(self.device),
12091209
slot_mapping=slot_mapping,
@@ -1437,15 +1437,15 @@ def forward(
14371437
value,
14381438
kv_cache,
14391439
attn_metadata.cache_seq_lens.to(torch.int32)
1440-
if self.is_batch_attention_opt
1440+
if self.is_batch_attention_opt and b_size > 1
14411441
else attn_metadata.cache_seq_lens,
14421442
attn_metadata.cache_offsets,
14431443
self.scale,
14441444
attn_metadata.local_block_tables,
14451445
self.scale, # dummy
14461446
]
14471447
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1448-
if self.is_batch_attention_opt:
1448+
if self.is_batch_attention_opt and b_size > 1:
14491449
decode_args.append(attn_metadata.swa_attn_masks)
14501450
else:
14511451
decode_args.append(None)

0 commit comments

Comments (0)