
Commit 24ffa11

remove decode mode in rpa

1 parent 39e9537 · commit 24ffa11

9 files changed: 11 additions & 223 deletions
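In brief: this commit drops the `decode_mode` flag from the ragged paged attention (RPA) kernel and all of its callers. The flag disappears from the kernel signature and its `static_argnames`, the decode-specific tuned table `TUNED_KV_PAGES_FOR_DECODE` and its lookup helpers are deleted, and the FlashAttention backend and benchmarks stop passing the flag. Block sizes are now always selected through `get_tuned_block_sizes`, for decode and prefill batches alike.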


benchmark/kernels/flash_attention/bench_flashattention.py

Lines changed: 0 additions & 1 deletion

@@ -112,7 +112,6 @@ def jitted_attn(
         cu_kv_lens,
         distribution,
         custom_mask=None,
-        decode_mode=0,
         causal=1,
         sm_scale=sm_scale,
     )

benchmark/kernels/flash_attention/get_block_spec_config.py

Lines changed: 0 additions & 1 deletion

@@ -105,7 +105,6 @@ def jitted_attn(
         cu_kv_lens,
         distribution,
         None,
-        decode_mode=0,
         sm_scale=sm_scale,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
4 binary files changed: −160 KB each (contents not shown)

python/sgl_jax/srt/kernels/ragged_paged_attention/ragged_paged_attention.py

Lines changed: 11 additions & 26 deletions

@@ -18,7 +18,6 @@
 from jax.experimental.pallas import tpu as pltpu

 from sgl_jax.srt.kernels.ragged_paged_attention.tuned_block_sizes import (
-    get_kv_pages_for_decode,
     get_tuned_block_sizes,
 )
 from sgl_jax.srt.kernels.ragged_paged_attention.util import (

@@ -1381,7 +1380,6 @@ def get_kernel_scope_name(bq_size, bkv_p, page_size):
     jax.jit,
     static_argnames=(
         "causal",
-        "decode_mode",
         "sm_scale",
         "sliding_window",
         "soft_cap",

@@ -1410,7 +1408,6 @@ def ragged_paged_attention(
     custom_mask: jax.Array,  # if causal is True, custom_mask shape is [patten_total_kv_len], else [0]
     *,
     causal: int = 1,  # 1: True, 0: False
-    decode_mode: int = 1,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
     soft_cap: float | None = None,

@@ -1511,29 +1508,17 @@ def ragged_paged_attention(
     bkv_p = num_kv_pages_per_block
     bq_sz = num_queries_per_block
     if bq_sz is None or bkv_p is None:
-        if decode_mode == 1:
-            bkv_p = get_kv_pages_for_decode(
-                q.dtype,
-                kv_cache_fused_processed.dtype,
-                actual_num_q_heads,
-                actual_num_kv_heads,
-                head_dim,
-                page_size,
-                pages_per_seq,
-            )
-            bq_sz = 1
-        else:
-            bkv_p, bq_sz = get_tuned_block_sizes(
-                q.dtype,
-                kv_cache_fused_processed.dtype,
-                actual_num_q_heads,
-                actual_num_kv_heads,
-                head_dim,
-                page_size,
-                max_num_tokens,
-                pages_per_seq,
-                causal,
-            )
+        bkv_p, bq_sz = get_tuned_block_sizes(
+            q.dtype,
+            kv_cache_fused_processed.dtype,
+            actual_num_q_heads,
+            actual_num_kv_heads,
+            head_dim,
+            page_size,
+            max_num_tokens,
+            pages_per_seq,
+            causal,
+        )
     kv_packing = get_dtype_packing(kv_cache_fused_processed.dtype)
     if page_size == 1:
         bkv_p = bkv_p // 2
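Because `decode_mode` was a `static_argnames` entry, removing it also removes a compilation axis: decode and prefill batches now share one compiled kernel per remaining static configuration. Below is a hedged sketch of a post-change call; only the trailing positional arguments (`cu_kv_lens`, `distribution`, `custom_mask`) and the keyword arguments are taken from this commit, while the leading argument names (`q`, `kv_cache_fused`, `kv_lens`, `page_indices`, `cu_q_lens`) are assumed stand-ins for inputs not shown in the diff:

# Sketch only; leading argument names are assumptions, not the verified API.
out = ragged_paged_attention(
    q,               # queries, e.g. [num_tokens, num_q_heads, head_dim] (assumed layout)
    kv_cache_fused,  # fused paged KV cache (assumed name)
    kv_lens,         # assumed name
    page_indices,    # assumed name
    cu_q_lens,       # assumed name
    cu_kv_lens,
    distribution,
    custom_mask,     # [patten_total_kv_len] when causal, else [0]
    causal=1,
    sm_scale=sm_scale,
    # decode_mode=...  <- no longer accepted after this commit
)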

python/sgl_jax/srt/kernels/ragged_paged_attention/tuned_block_sizes.py

Lines changed: 0 additions & 192 deletions

@@ -1514,195 +1514,3 @@ def get_simplified_key(
         next_power_of_2(page_size),
         next_power_of_2(max_num_tokens),
     )
-
-
-TUNED_KV_PAGES_FOR_DECODE = {
-    # key
-    # - device_name
-    # - q dtype
-    # - kv dtype
-    # - q head number
-    # - kv head number
-    # - head dim
-    # - page_size
-    # value:
-    # - num_kv_pages_per_block
-    "TPU v6e": {
-        ("bfloat16", "bfloat16", 1, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 2, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 4, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 8, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 16, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 2, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 4, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 8, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 16, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 32, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 4, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 8, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 16, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 32, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 8, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 16, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 32, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 64, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 16, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 32, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 64, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 128, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 1, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 2, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 4, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 8, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 16, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 2, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 4, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 8, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 16, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 4, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 8, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 16, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 32, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 8, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 16, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 32, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 64, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 16, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 32, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 64, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 128, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 256, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 512, 16, 128, 256): 8,
-    },
-    "TPU v7": {
-        ("bfloat16", "bfloat16", 1, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 2, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 4, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 8, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 16, 1, 128, 128): 256,
-        ("bfloat16", "bfloat16", 2, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 4, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 8, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 16, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 32, 2, 128, 128): 128,
-        ("bfloat16", "bfloat16", 4, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 8, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 16, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 32, 4, 128, 128): 64,
-        ("bfloat16", "bfloat16", 8, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 16, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 32, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 64, 8, 128, 128): 32,
-        ("bfloat16", "bfloat16", 16, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 32, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 64, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 128, 16, 128, 128): 16,
-        ("bfloat16", "bfloat16", 1, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 2, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 4, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 8, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 16, 1, 128, 256): 128,
-        ("bfloat16", "bfloat16", 2, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 4, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 8, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 16, 2, 128, 256): 64,
-        ("bfloat16", "bfloat16", 4, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 8, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 16, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 32, 4, 128, 256): 32,
-        ("bfloat16", "bfloat16", 8, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 16, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 32, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 64, 8, 128, 256): 16,
-        ("bfloat16", "bfloat16", 16, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 32, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 64, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 128, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 256, 16, 128, 256): 8,
-        ("bfloat16", "bfloat16", 512, 16, 128, 256): 8,
-    },
-}
-
-
-def get_kv_pages_for_decode(
-    q_dtype,
-    kv_dtype,
-    actual_num_q_heads,
-    actual_num_kv_heads,
-    head_dim,
-    page_size,
-    pages_per_seq,
-    causal=True,
-) -> int:
-    if not causal:
-        # FIXME(pc) hack this to avoid oom when precompile, currently, we still have no better choice for non-causal's mask
-        # this should be optimied future
-        return 4
-    """Look up for the best num_kv_pages_per_blk from auto-tuned table."""
-    tpu_version = get_tpu_version()
-
-    if tpu_version < 4:
-        raise NotImplementedError("TPU version must be 4 or higher.")
-    keys = get_simplified_key_for_decode(
-        page_size,
-        q_dtype,
-        kv_dtype,
-        actual_num_q_heads,
-        actual_num_kv_heads,
-        head_dim,
-    )
-
-    device_name = keys[0]
-
-    # Default block sizes.
-    bkv_p = 1024 // page_size
-    if tpu_version == 4:
-        # TPUv4 has much smaller VMEM size so we pick fixed block sizes.
-        bkv_p = 512 // page_size
-    else:
-        if (
-            device_name in TUNED_KV_PAGES_FOR_DECODE
-            and keys[1:] in TUNED_KV_PAGES_FOR_DECODE[device_name]
-        ):
-            bkv_p = TUNED_KV_PAGES_FOR_DECODE[device_name][keys[1:]]
-        else:
-            logger.info(
-                "Tuned RPA kv page not found for %s: page_size=%s, actual_num_q_heads=%s, "
-                "actual_num_kv_heads=%s, head_dim=%s, pages_per_seq=%s.",
-                device_name,
-                page_size,
-                actual_num_q_heads,
-                actual_num_kv_heads,
-                head_dim,
-                pages_per_seq,
-            )
-            logger.info("Using default block size: bkv_p=%s.", bkv_p)
-
-    return min(pages_per_seq, bkv_p)
-
-
-def get_simplified_key_for_decode(
-    page_size,
-    q_dtype,
-    kv_dtype,
-    num_q_heads,
-    num_kv_heads,
-    head_dim,
-):
-    """Get the simplified key to reduce the number of combinations."""
-    assert num_q_heads % num_kv_heads == 0
-    device = get_device_name()
-    q_dtype_name = jnp.dtype(q_dtype).name
-    kv_dtype_name = jnp.dtype(kv_dtype).name
-    num_q_heads = next_power_of_2(num_q_heads)
-    num_kv_heads = next_power_of_2(num_kv_heads)
-
-    return (
-        device,
-        q_dtype_name,
-        kv_dtype_name,
-        num_q_heads,
-        num_kv_heads,
-        (head_dim + 127) // 128 * 128,
-        next_power_of_2(page_size),
-    )
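With the decode-specific table gone, decode block sizes come from the same `get_tuned_block_sizes` table as prefill. A minimal sketch of the unified lookup, where the argument order mirrors the call in ragged_paged_attention.py above and the concrete values are made up for illustration:

import jax.numpy as jnp

from sgl_jax.srt.kernels.ragged_paged_attention.tuned_block_sizes import (
    get_tuned_block_sizes,
)

# Illustrative decode-style lookup through the unified path; all values below
# are example inputs, not tuned recommendations.
bkv_p, bq_sz = get_tuned_block_sizes(
    jnp.bfloat16,  # q dtype
    jnp.bfloat16,  # kv-cache dtype
    32,            # actual_num_q_heads
    8,             # actual_num_kv_heads
    128,           # head_dim
    128,           # page_size
    64,            # max_num_tokens (decode: one token per sequence)
    32,            # pages_per_seq
    1,             # causal
)

Note the design shift: previously a pure-decode batch forced bq_sz = 1 and looked up bkv_p in the decode table; now the tuned table alone decides both values, with a small max_num_tokens serving as the implicit decode signal.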

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 0 additions & 3 deletions

@@ -510,8 +510,6 @@ def __call__(
         if hasattr(token_to_kv_pool, "remap_cache_loc") and self.page_size == 1:
             page_indices_arg = token_to_kv_pool.remap_cache_loc(page_indices_arg, layer.layer_id)

-        decode_mode = 1 if forward_batch.forward_mode == ForwardMode.DECODE else 0
-
         in_specs = (
             P(self.attention_data_partition_axis, self.kv_partition_axis),  # queries
             P(self.attention_data_partition_axis, self.kv_partition_axis),  # keys (new tokens)

@@ -545,7 +543,6 @@ def _ragged_paged_attention_with_fused_kv(*args):
             kv_cache_fused,
             *other_args,
             causal=causal,
-            decode_mode=decode_mode,
             sm_scale=scale,
             sliding_window=layer.sliding_window_size,
             soft_cap=layer.logit_cap,
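After this change the backend no longer inspects `forward_batch.forward_mode` when dispatching the kernel. A sketch of the trimmed wrapper rendered as a standalone function, assuming the names visible in this hunk (in the repo it is a closure inside `__call__`, with `causal`, `scale`, and `layer` captured from the enclosing scope, and the shard_map/in_specs plumbing elided):

from sgl_jax.srt.kernels.ragged_paged_attention.ragged_paged_attention import (
    ragged_paged_attention,
)

# Standalone rendering of the trimmed closure; parameter passing replaces the
# captured variables used in the actual backend code.
def attention_with_fused_kv(queries, kv_cache_fused, *other_args, causal, scale, layer):
    return ragged_paged_attention(
        queries,
        kv_cache_fused,
        *other_args,
        causal=causal,
        sm_scale=scale,  # decode_mode= is no longer passed here
        sliding_window=layer.sliding_window_size,
        soft_cap=layer.logit_cap,
    )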
