@@ -1474,7 +1474,7 @@ def forward(
14741474 )
14751475
14761476 if q_len == 1 :
1477- attn_output = causal_attention_naive_decode ( # noqa: E501
1477+ decode_args = [
14781478 query ,
14791479 key ,
14801480 value ,
@@ -1483,9 +1483,14 @@ def forward(
14831483 self .scale ,
14841484 attn_metadata .block_tables .to (torch .int16 ),
14851485 self .scale , # dummy (required by rbln_triton_ops signature)
1486+ ]
1487+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1488+ decode_args .append (self .sinks )
1489+ attn_output = causal_attention_naive_decode ( # noqa: E501
1490+ * decode_args ,
14861491 )
14871492 else :
1488- attn_output = causal_attention_naive_prefill ( # noqa: E501
1493+ prefill_args = [
14891494 query ,
14901495 key ,
14911496 value ,
@@ -1494,6 +1499,11 @@ def forward(
14941499 self .scale ,
14951500 attn_metadata .block_tables .to (torch .int16 ),
14961501 self .scale , # dummy (required by rbln_triton_ops signature)
1502+ ]
1503+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1504+ prefill_args .append (self .sinks )
1505+ attn_output = causal_attention_naive_prefill ( # noqa: E501
1506+ * prefill_args ,
14971507 )
14981508 else :
14991509 if envs .VLLM_RBLN_COMPILE_MODEL :
@@ -1524,7 +1534,7 @@ def forward(
15241534 # * otherwise - seq_lens[B, P] == dyn_size_for_partitions,
15251535 # dynamic size for each partition
15261536 if q_len == 1 :
1527- attn_output = flash_causal_attention_naive_decode ( # noqa: E501
1537+ decode_args = [
15281538 query ,
15291539 key ,
15301540 value ,
@@ -1533,9 +1543,14 @@ def forward(
15331543 attn_metadata .seq_lens .to (torch .int16 ),
15341544 attn_metadata .block_tables .to (torch .int16 ),
15351545 self .scale , # dummy
1546+ ]
1547+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1548+ decode_args .append (self .sinks )
1549+ attn_output = flash_causal_attention_naive_decode ( # noqa: E501
1550+ * decode_args ,
15361551 )
15371552 else :
1538- attn_output = flash_causal_attention_naive_prefill ( # noqa: E501
1553+ prefill_args = [
15391554 query ,
15401555 key ,
15411556 value ,
@@ -1544,6 +1559,11 @@ def forward(
15441559 attn_metadata .seq_lens .to (torch .int16 ),
15451560 attn_metadata .block_tables .to (torch .int16 ),
15461561 self .scale , # dummy
1562+ ]
1563+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1564+ prefill_args .append (self .sinks )
1565+ attn_output = flash_causal_attention_naive_prefill ( # noqa: E501
1566+ * prefill_args ,
15471567 )
15481568 else :
15491569 if self .is_normal :
@@ -1568,7 +1588,7 @@ def forward(
15681588 )
15691589
15701590 if q_len == 1 :
1571- attn_output = attention_naive_decode ( # noqa: E501
1591+ decode_args = [
15721592 query ,
15731593 key ,
15741594 value ,
@@ -1578,9 +1598,14 @@ def forward(
15781598 self .scale ,
15791599 attn_metadata .block_tables .to (torch .int16 ),
15801600 self .scale , # dummy (required by rbln_triton_ops signature)
1601+ ]
1602+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1603+ decode_args .append (self .sinks )
1604+ attn_output = attention_naive_decode ( # noqa: E501
1605+ * decode_args ,
15811606 )
15821607 else :
1583- attn_output = attention_naive_prefill ( # noqa: E501
1608+ prefill_args = [
15841609 query ,
15851610 key ,
15861611 value ,
@@ -1590,6 +1615,11 @@ def forward(
15901615 self .scale ,
15911616 attn_metadata .block_tables .to (torch .int16 ),
15921617 self .scale , # dummy (required by rbln_triton_ops signature)
1618+ ]
1619+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1620+ prefill_args .append (self .sinks )
1621+ attn_output = attention_naive_prefill ( # noqa: E501
1622+ * prefill_args ,
15931623 )
15941624 else :
15951625 if envs .VLLM_RBLN_COMPILE_MODEL :
@@ -1612,7 +1642,7 @@ def forward(
16121642 flash_attention_naive_decode = flash_attention_naive_decode_impl
16131643
16141644 if q_len == 1 :
1615- attn_output = flash_attention_naive_decode ( # noqa: E501
1645+ decode_args = [
16161646 query ,
16171647 key ,
16181648 value ,
@@ -1622,9 +1652,14 @@ def forward(
16221652 attn_metadata .seq_lens .to (torch .int16 ),
16231653 attn_metadata .block_tables .to (torch .int16 ),
16241654 self .scale , # dummy
1655+ ]
1656+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1657+ decode_args .append (self .sinks )
1658+ attn_output = flash_attention_naive_decode ( # noqa: E501
1659+ * decode_args ,
16251660 )
16261661 else :
1627- attn_output = flash_attention_naive_prefill ( # noqa: E501
1662+ prefill_args = [
16281663 query ,
16291664 key ,
16301665 value ,
@@ -1634,6 +1669,11 @@ def forward(
16341669 attn_metadata .seq_lens .to (torch .int16 ),
16351670 attn_metadata .block_tables .to (torch .int16 ),
16361671 self .scale , # dummy
1672+ ]
1673+ if not envs .VLLM_RBLN_USE_CUSTOM_KERNEL :
1674+ prefill_args .append (self .sinks )
1675+ attn_output = flash_attention_naive_prefill ( # noqa: E501
1676+ * prefill_args ,
16371677 )
16381678
16391679 # 2. attention output reshape for attention backend return
0 commit comments