Skip to content

Commit 6f788a8

Browse files
fix: add sink argument for contrib.custom_ops (#416)
1 parent 0a98aba commit 6f788a8

1 file changed

Lines changed: 48 additions & 8 deletions

File tree

vllm_rbln/v1/attention/backends/flash_attention.py

Lines changed: 48 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1474,7 +1474,7 @@ def forward(
14741474
)
14751475

14761476
if q_len == 1:
1477-
attn_output = causal_attention_naive_decode( # noqa: E501
1477+
decode_args = [
14781478
query,
14791479
key,
14801480
value,
@@ -1483,9 +1483,14 @@ def forward(
14831483
self.scale,
14841484
attn_metadata.block_tables.to(torch.int16),
14851485
self.scale, # dummy (required by rbln_triton_ops signature)
1486+
]
1487+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1488+
decode_args.append(self.sinks)
1489+
attn_output = causal_attention_naive_decode( # noqa: E501
1490+
*decode_args,
14861491
)
14871492
else:
1488-
attn_output = causal_attention_naive_prefill( # noqa: E501
1493+
prefill_args = [
14891494
query,
14901495
key,
14911496
value,
@@ -1494,6 +1499,11 @@ def forward(
14941499
self.scale,
14951500
attn_metadata.block_tables.to(torch.int16),
14961501
self.scale, # dummy (required by rbln_triton_ops signature)
1502+
]
1503+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1504+
prefill_args.append(self.sinks)
1505+
attn_output = causal_attention_naive_prefill( # noqa: E501
1506+
*prefill_args,
14971507
)
14981508
else:
14991509
if envs.VLLM_RBLN_COMPILE_MODEL:
@@ -1524,7 +1534,7 @@ def forward(
15241534
# * otherwise - seq_lens[B, P] == dyn_size_for_partitions,
15251535
# dynamic size for each partition
15261536
if q_len == 1:
1527-
attn_output = flash_causal_attention_naive_decode( # noqa: E501
1537+
decode_args = [
15281538
query,
15291539
key,
15301540
value,
@@ -1533,9 +1543,14 @@ def forward(
15331543
attn_metadata.seq_lens.to(torch.int16),
15341544
attn_metadata.block_tables.to(torch.int16),
15351545
self.scale, # dummy
1546+
]
1547+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1548+
decode_args.append(self.sinks)
1549+
attn_output = flash_causal_attention_naive_decode( # noqa: E501
1550+
*decode_args,
15361551
)
15371552
else:
1538-
attn_output = flash_causal_attention_naive_prefill( # noqa: E501
1553+
prefill_args = [
15391554
query,
15401555
key,
15411556
value,
@@ -1544,6 +1559,11 @@ def forward(
15441559
attn_metadata.seq_lens.to(torch.int16),
15451560
attn_metadata.block_tables.to(torch.int16),
15461561
self.scale, # dummy
1562+
]
1563+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1564+
prefill_args.append(self.sinks)
1565+
attn_output = flash_causal_attention_naive_prefill( # noqa: E501
1566+
*prefill_args,
15471567
)
15481568
else:
15491569
if self.is_normal:
@@ -1568,7 +1588,7 @@ def forward(
15681588
)
15691589

15701590
if q_len == 1:
1571-
attn_output = attention_naive_decode( # noqa: E501
1591+
decode_args = [
15721592
query,
15731593
key,
15741594
value,
@@ -1578,9 +1598,14 @@ def forward(
15781598
self.scale,
15791599
attn_metadata.block_tables.to(torch.int16),
15801600
self.scale, # dummy (required by rbln_triton_ops signature)
1601+
]
1602+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1603+
decode_args.append(self.sinks)
1604+
attn_output = attention_naive_decode( # noqa: E501
1605+
*decode_args,
15811606
)
15821607
else:
1583-
attn_output = attention_naive_prefill( # noqa: E501
1608+
prefill_args = [
15841609
query,
15851610
key,
15861611
value,
@@ -1590,6 +1615,11 @@ def forward(
15901615
self.scale,
15911616
attn_metadata.block_tables.to(torch.int16),
15921617
self.scale, # dummy (required by rbln_triton_ops signature)
1618+
]
1619+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1620+
prefill_args.append(self.sinks)
1621+
attn_output = attention_naive_prefill( # noqa: E501
1622+
*prefill_args,
15931623
)
15941624
else:
15951625
if envs.VLLM_RBLN_COMPILE_MODEL:
@@ -1612,7 +1642,7 @@ def forward(
16121642
flash_attention_naive_decode = flash_attention_naive_decode_impl
16131643

16141644
if q_len == 1:
1615-
attn_output = flash_attention_naive_decode( # noqa: E501
1645+
decode_args = [
16161646
query,
16171647
key,
16181648
value,
@@ -1622,9 +1652,14 @@ def forward(
16221652
attn_metadata.seq_lens.to(torch.int16),
16231653
attn_metadata.block_tables.to(torch.int16),
16241654
self.scale, # dummy
1655+
]
1656+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1657+
decode_args.append(self.sinks)
1658+
attn_output = flash_attention_naive_decode( # noqa: E501
1659+
*decode_args,
16251660
)
16261661
else:
1627-
attn_output = flash_attention_naive_prefill( # noqa: E501
1662+
prefill_args = [
16281663
query,
16291664
key,
16301665
value,
@@ -1634,6 +1669,11 @@ def forward(
16341669
attn_metadata.seq_lens.to(torch.int16),
16351670
attn_metadata.block_tables.to(torch.int16),
16361671
self.scale, # dummy
1672+
]
1673+
if not envs.VLLM_RBLN_USE_CUSTOM_KERNEL:
1674+
prefill_args.append(self.sinks)
1675+
attn_output = flash_attention_naive_prefill( # noqa: E501
1676+
*prefill_args,
16371677
)
16381678

16391679
# 2. attention output reshape for attention backend return

0 commit comments

Comments (0)