[Kernel] support kv cache quantization in ragged attention kernel #9249

Merged
merged 6 commits on May 30, 2025
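
This change lets the ragged paged attention op consume a KV cache stored in fp8 (e.g. `torch.float8_e5m2`) and dequantize it on the fly with optional per-tensor `k_scale` / `v_scale` factors. Below is a minimal sketch of the dequantization convention, mirroring the reference (non-kernel) path in `custom_kernel.py` further down; shapes are illustrative and borrowed from the test helper.

```python
import torch

# Illustrative shapes only: kv_pages is [num_pages, page_size, num_kv_heads * 2, head_dim],
# stored as fp8. torch.randn cannot sample fp8 directly, so sample in float32 and cast.
kv_pages = torch.randn(1000, 16, 2 * 8, 128, dtype=torch.float32).to(torch.float8_e5m2)
k_scale = v_scale = 0.5  # per-tensor scales, as used in the updated tests
q_dtype = torch.bfloat16

# K occupies the even head slots, V the odd ones.
k_fp8 = kv_pages[:, :, 0::2, :]
v_fp8 = kv_pages[:, :, 1::2, :]

# Dequantize the way the reference path does: upcast to float32, apply the scale,
# then cast to the query dtype before computing attention.
k = (k_fp8.to(torch.float32) * k_scale).to(q_dtype)
v = (v_fp8.to(torch.float32) * v_scale).to(q_dtype)
```
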
63 changes: 44 additions & 19 deletions test/test_pallas.py
@@ -14,7 +14,7 @@
import numpy as np

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
from torch_xla.experimental.custom_kernel import jax_import_guard, convert_torch_dtype_to_jax
jax_import_guard()
import jax
import jax.numpy as jnp
@@ -98,7 +98,8 @@ def _ragged_pagedattention_generate_qkv(
head_dim,
page_size,
num_pages,
dtype,
q_dtype,
kv_dtype,
*,
max_num_batched_tokens=None,
max_num_seqs=16,
@@ -129,10 +130,11 @@ def _ragged_pagedattention_generate_qkv(
kv_lens = torch.nn.functional.pad(kv_lens,
(0, max_num_seqs - kv_lens.shape[0]),
"constant", 0)
# Use float32 for randn because randn doesn't support some dtypes such as float8
q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
dtype=dtype)
dtype=torch.float32).to(q_dtype)
kv_pages = torch.randn((num_pages, page_size, num_kv_heads * 2, head_dim),
dtype=dtype)
dtype=torch.float32).to(kv_dtype)
page_indices = torch.randint(
0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
return q, kv_pages, kv_lens, page_indices, cu_q_lens
@@ -632,7 +634,8 @@ def _test_ragged_paged_attention(
head_dim,
page_size,
num_pages,
dtype,
q_dtype,
kv_dtype,
*,
sm_scale=1.0,
sliding_window=None,
@@ -654,9 +657,18 @@
head_dim,
page_size,
num_pages,
dtype,
q_dtype,
kv_dtype,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs)
k_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
v_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
num_kv_heads = num_heads[1]
if num_kv_heads == 1 and kv_dtype in [torch.float8_e5m2]:
self.skipTest(
"attention kernel cannot support because it is not XLA fully tiled")
if kv_dtype is torch.float8_e5m2 and tpu.version() <= 4:
self.skipTest("TPU v4 or older doesn't support fp8")

q_xla = q.to("xla")
kv_pages_xla = kv_pages.to("xla")
@@ -677,6 +689,8 @@ def ragged_paged_attention_wrapper(
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale,
use_kernel=True,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
@@ -691,6 +705,8 @@ def ragged_paged_attention_wrapper(
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale,
use_kernel=use_kernel,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
@@ -711,6 +727,8 @@ def ragged_paged_attention_wrapper(
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale,
use_kernel=True,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
@@ -726,6 +744,8 @@ def ragged_paged_attention_wrapper(
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale,
use_kernel=False,
)

@@ -734,17 +754,14 @@ def ragged_paged_attention_wrapper(
self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)

assert dtype == torch.float32 or dtype == torch.bfloat16
jnp_dtype = jnp.float32
tol = 0.15
if dtype == torch.bfloat16:
jnp_dtype = jnp.bfloat16
tol = 0.3
tol = 0.15 if q_dtype == torch.float32 else 0.3
q_jnp_dtype = convert_torch_dtype_to_jax(q_dtype)
kv_jnp_dtype = convert_torch_dtype_to_jax(kv_dtype)

# Numpy does not support bfloat16 directly, so we convert to float32 first.
q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=jnp_dtype)
q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=q_jnp_dtype)
kv_pages_jax = jnp.array(
kv_pages.to(torch.float32).numpy(), dtype=jnp_dtype)
kv_pages.to(torch.float32).numpy(), dtype=kv_jnp_dtype)
kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
@@ -765,7 +782,9 @@ def ragged_paged_attention_wrapper(
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
)[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(dtype)
k_scale=k_scale,
v_scale=v_scale,
)[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(q_dtype)
jax_kernel_output_cpu = jax_kernel_output.cpu()

torch.testing.assert_close(
@@ -776,7 +795,8 @@
@parameterized.product(
seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
num_heads=[(32, 8), (8, 1)],
dtype=[torch.float32, torch.bfloat16],
dtype=[(torch.bfloat16, torch.bfloat16),
(torch.bfloat16, torch.float8_e5m2)],
sm_scale=[1.0, 0.5],
sliding_window=[None, 128],
soft_cap=[None, 10.0],
@@ -796,14 +816,16 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
head_dim = 128
page_size = 16
num_pages = 1000
q_dtype, kv_dtype = dtype

self._test_ragged_paged_attention(
seq_lens,
num_heads,
head_dim,
page_size,
num_pages,
dtype,
q_dtype,
kv_dtype,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
@@ -814,7 +836,8 @@
@parameterized.product(
seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
num_heads=[(32, 8), (8, 1)],
dtype=[torch.float32, torch.bfloat16],
dtype=[(torch.bfloat16, torch.bfloat16),
(torch.bfloat16, torch.float8_e5m2)],
sm_scale=[1.0, 0.5],
sliding_window=[None, 128],
soft_cap=[None, 10.0],
@@ -835,14 +858,16 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
head_dim = 128
page_size = 16
num_pages = 1000
q_dtype, kv_dtype = dtype

self._test_ragged_paged_attention(
seq_lens,
num_heads,
head_dim,
page_size,
num_pages,
dtype,
q_dtype,
kv_dtype,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
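The JAX reference comparison in the test needs matching dtypes on both sides. numpy cannot represent bfloat16 or float8 directly, so tensors are round-tripped through float32 and re-cast on the JAX side using `convert_torch_dtype_to_jax`, which the next file extends with the fp8 dtypes. A small sketch of that bridge (assuming a TPU-enabled torch_xla/JAX environment; shapes are illustrative):

```python
import torch
import jax.numpy as jnp
from torch_xla.experimental.custom_kernel import convert_torch_dtype_to_jax

kv_dtype = torch.float8_e5m2
kv_jnp_dtype = convert_torch_dtype_to_jax(kv_dtype)  # jnp.float8_e5m2 after this change

# torch.randn cannot sample fp8 directly, so sample in float32 and cast.
kv_pages = torch.randn(4, 16, 2 * 8, 128, dtype=torch.float32).to(kv_dtype)

# numpy has no bfloat16/float8, so go through float32 and re-cast to the jnp dtype.
kv_pages_jax = jnp.array(kv_pages.to(torch.float32).numpy(), dtype=kv_jnp_dtype)
```
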
30 changes: 29 additions & 1 deletion torch_xla/experimental/custom_kernel.py
@@ -183,6 +183,12 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
elif dtype == torch.float8_e5m2:
return jnp.float8_e5m2
elif dtype == torch.float8_e4m3fn:
return jnp.float8_e4m3fn
elif dtype == torch.float8_e4m3fnuz:
return jnp.float8_e4m3fnuz
else:
raise ValueError(f"Unsupported dtype: {dtype}")

@@ -901,6 +907,8 @@ def _ragged_paged_attention_nonkernel(
page_indices, # i32[max_num_seqs, pages_per_seq]
cu_q_lens, # i32[max_num_seqs + 1]
num_seqs, # i32[1]
k_scale,
v_scale,
*,
sm_scale=1.0,
sliding_window: int | None = None,
@@ -927,6 +935,12 @@
head_dim)[:kv_len]
v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads,
head_dim)[:kv_len]
if k_scale is not None:
k = k.to(torch.float32) * k_scale
k = k.to(q.dtype)
if v_scale is not None:
v = v.to(torch.float32) * v_scale
v = v.to(q.dtype)
k = torch.repeat_interleave(k, num_query_per_kv, dim=1)
v = torch.repeat_interleave(v, num_query_per_kv, dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k)
@@ -963,6 +977,8 @@ def ragged_paged_attention(
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value=None,
k_scale: float | None = None,
v_scale: float | None = None,
use_kernel=True,
# kernel tuning parameters
num_kv_pages_per_block=None,
@@ -984,6 +1000,8 @@
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
k_scale=k_scale,
v_scale=v_scale,
)

# Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -1005,6 +1023,8 @@
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
k_scale=k_scale,
v_scale=v_scale,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
vmem_limit_bytes=vmem_limit_bytes,
@@ -1016,6 +1036,8 @@
"num_kv_pages_per_block",
"num_queries_per_block",
"vmem_limit_bytes",
"k_scale",
"v_scale",
],
)

@@ -1485,7 +1507,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
XLA_LIB.define(
"ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
"Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
"float? soft_cap=None, float? mask_value=None, bool use_kernel=True,"
"float? soft_cap=None, float? mask_value=None, float? k_scale=None, float? v_scale=None, bool use_kernel=True,"
"int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
)

@@ -1502,6 +1524,8 @@ def ragged_paged_attention_xla(
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value=None,
k_scale: float | None = None,
v_scale: float | None = None,
use_kernel=True,
# kernel tuning parameters
num_kv_pages_per_block=None,
@@ -1519,6 +1543,8 @@
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
k_scale=k_scale,
v_scale=v_scale,
use_kernel=use_kernel,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
@@ -1537,6 +1563,8 @@ def ragged_paged_attention_non_xla(
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value=None,
k_scale: float | None = None,
v_scale: float | None = None,
use_kernel=True,
# kernel tuning parameters
num_kv_pages_per_block=None,
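
Putting the schema and wrapper together, callers can now pass optional `k_scale` / `v_scale` floats alongside the existing arguments. A hedged usage sketch (not part of this PR) on an XLA device follows; shapes, padding, and values are illustrative and loosely follow the test helper, and the `num_seqs` placement may need adjusting for a real setup.

```python
import torch
from torch_xla.experimental.custom_kernel import ragged_paged_attention

num_q_heads, num_kv_heads, head_dim = 32, 8, 128
page_size, num_pages = 16, 1000
max_num_seqs, pages_per_seq = 16, 16

# One sequence: 4 query tokens attending to 128 cached KV tokens. The KV cache is stored
# in float8_e5m2 and dequantized inside the op via k_scale / v_scale.
q = torch.randn(4, num_q_heads, head_dim, dtype=torch.bfloat16).to("xla")
kv_pages = (torch.randn(num_pages, page_size, 2 * num_kv_heads, head_dim,
                        dtype=torch.float32)
            .to(torch.float8_e5m2).to("xla"))
kv_lens = torch.tensor([128] + [0] * (max_num_seqs - 1), dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32).to("xla")
cu_q_lens = torch.tensor([0] + [4] * max_num_seqs, dtype=torch.int32).to("xla")
num_seqs = torch.tensor([1], dtype=torch.int32)

out = ragged_paged_attention(
    q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=head_dim**-0.5,
    k_scale=0.5,  # per-tensor dequantization scales for the fp8 KV cache
    v_scale=0.5,
    use_kernel=True)
```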