
Commit 2edcd2e

[Kernel] support kv cache quantization in ragged attention kernel (#9249)
1 parent 248a8b3 commit 2edcd2e

File tree: 3 files changed, +165 -57 lines changed

test/test_pallas.py

Lines changed: 44 additions & 19 deletions
@@ -14,7 +14,7 @@
 import numpy as np
 
 if xr.device_type() == 'TPU':
-  from torch_xla.experimental.custom_kernel import jax_import_guard
+  from torch_xla.experimental.custom_kernel import jax_import_guard, convert_torch_dtype_to_jax
   jax_import_guard()
   import jax
   import jax.numpy as jnp
@@ -98,7 +98,8 @@ def _ragged_pagedattention_generate_qkv(
     head_dim,
     page_size,
     num_pages,
-    dtype,
+    q_dtype,
+    kv_dtype,
     *,
     max_num_batched_tokens=None,
     max_num_seqs=16,
@@ -129,10 +130,11 @@ def _ragged_pagedattention_generate_qkv(
   kv_lens = torch.nn.functional.pad(kv_lens,
                                     (0, max_num_seqs - kv_lens.shape[0]),
                                     "constant", 0)
+  # Use float32 for randn because it doesn't support some dtypes like float8
   q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
-                  dtype=dtype)
+                  dtype=torch.float32).to(q_dtype)
   kv_pages = torch.randn((num_pages, page_size, num_kv_heads * 2, head_dim),
-                         dtype=dtype)
+                         dtype=torch.float32).to(kv_dtype)
   page_indices = torch.randint(
       0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
   return q, kv_pages, kv_lens, page_indices, cu_q_lens
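
As the new comment in the hunk above notes, torch.randn cannot produce float8 tensors directly, so the test now draws data in float32 and casts afterwards. A minimal standalone sketch of that pattern, reusing the test's default shapes (nothing here is torch_xla-specific):

```python
import torch

# Draw test data in float32, then cast to the fp8 KV-cache dtype:
# torch.randn does not support float8 dtypes directly.
num_pages, page_size, num_kv_heads, head_dim = 1000, 16, 8, 128
kv_pages = torch.randn(
    (num_pages, page_size, num_kv_heads * 2, head_dim),
    dtype=torch.float32).to(torch.float8_e5m2)
print(kv_pages.dtype)  # torch.float8_e5m2
```
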
@@ -632,7 +634,8 @@ def _test_ragged_paged_attention(
       head_dim,
       page_size,
       num_pages,
-      dtype,
+      q_dtype,
+      kv_dtype,
       *,
       sm_scale=1.0,
       sliding_window=None,
@@ -654,9 +657,18 @@ def _test_ragged_paged_attention(
         head_dim,
         page_size,
         num_pages,
-        dtype,
+        q_dtype,
+        kv_dtype,
         max_num_batched_tokens=max_num_batched_tokens,
         max_num_seqs=max_num_seqs)
+    k_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    v_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    num_kv_heads = num_heads[1]
+    if num_kv_heads == 1 and kv_dtype in [torch.float8_e5m2]:
+      self.skipTest(
+          "attention kernel cannot support because it is not XLA fully tiled")
+    if kv_dtype is torch.float8_e5m2 and tpu.version() <= 4:
+      self.skipTest("TPU v4 or older doesn't support fp8")
 
     q_xla = q.to("xla")
     kv_pages_xla = kv_pages.to("xla")
@@ -677,6 +689,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -691,6 +705,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=use_kernel,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -711,6 +727,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -726,6 +744,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=False,
     )
 
@@ -734,17 +754,14 @@ def ragged_paged_attention_wrapper(
     self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
     self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)
 
-    assert dtype == torch.float32 or dtype == torch.bfloat16
-    jnp_dtype = jnp.float32
-    tol = 0.15
-    if dtype == torch.bfloat16:
-      jnp_dtype = jnp.bfloat16
-      tol = 0.3
+    tol = 0.15 if q_dtype == torch.float32 else 0.3
+    q_jnp_dtype = convert_torch_dtype_to_jax(q_dtype)
+    kv_jnp_dtype = convert_torch_dtype_to_jax(kv_dtype)
 
     # Numpy does not support bfloat16 directly. So we convert f32 first.
-    q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=jnp_dtype)
+    q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=q_jnp_dtype)
     kv_pages_jax = jnp.array(
-        kv_pages.to(torch.float32).numpy(), dtype=jnp_dtype)
+        kv_pages.to(torch.float32).numpy(), dtype=kv_jnp_dtype)
     kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
@@ -765,7 +782,9 @@ def ragged_paged_attention_wrapper(
             sm_scale=sm_scale,
             sliding_window=sliding_window,
             soft_cap=soft_cap,
-        )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(dtype)
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(q_dtype)
     jax_kernel_output_cpu = jax_kernel_output.cpu()
 
     torch.testing.assert_close(
@@ -776,7 +795,8 @@ def ragged_paged_attention_wrapper(
   @parameterized.product(
       seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
       num_heads=[(32, 8), (8, 1)],
-      dtype=[torch.float32, torch.bfloat16],
+      dtype=[(torch.bfloat16, torch.bfloat16),
+             (torch.bfloat16, torch.float8_e5m2)],
      sm_scale=[1.0, 0.5],
      sliding_window=[None, 128],
      soft_cap=[None, 10.0],
@@ -796,14 +816,16 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
     head_dim = 128
     page_size = 16
     num_pages = 1000
+    q_dtype, kv_dtype = dtype
 
     self._test_ragged_paged_attention(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype,
+        q_dtype,
+        kv_dtype,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
@@ -814,7 +836,8 @@
   @parameterized.product(
       seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
       num_heads=[(32, 8), (8, 1)],
-      dtype=[torch.float32, torch.bfloat16],
+      dtype=[(torch.bfloat16, torch.bfloat16),
+             (torch.bfloat16, torch.float8_e5m2)],
      sm_scale=[1.0, 0.5],
      sliding_window=[None, 128],
      soft_cap=[None, 10.0],
@@ -835,14 +858,16 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
     head_dim = 128
     page_size = 16
     num_pages = 1000
+    q_dtype, kv_dtype = dtype
 
     self._test_ragged_paged_attention(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype,
+        q_dtype,
+        kv_dtype,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,

torch_xla/experimental/custom_kernel.py

Lines changed: 29 additions & 1 deletion
@@ -183,6 +183,12 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
     return jnp.int8
   elif dtype == torch.uint8:
     return jnp.uint8
+  elif dtype == torch.float8_e5m2:
+    return jnp.float8_e5m2
+  elif dtype == torch.float8_e4m3fn:
+    return jnp.float8_e4m3fn
+  elif dtype == torch.float8_e4m3fnuz:
+    return jnp.float8_e4m3fnuz
   else:
     raise ValueError(f"Unsupported dtype: {dtype}")
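
A quick, hedged check of the new dtype mapping (assumes a torch_xla build with JAX importable; the updated test relies on this conversion when it rebuilds jnp arrays for the JAX reference kernel):

```python
import torch
import jax.numpy as jnp
from torch_xla.experimental.custom_kernel import convert_torch_dtype_to_jax

# fp8 torch dtypes now resolve to their JAX counterparts instead of raising
# ValueError("Unsupported dtype: ...").
assert convert_torch_dtype_to_jax(torch.float8_e5m2) == jnp.float8_e5m2
assert convert_torch_dtype_to_jax(torch.float8_e4m3fn) == jnp.float8_e4m3fn
```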

@@ -901,6 +907,8 @@ def _ragged_paged_attention_nonkernel(
     page_indices,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    k_scale,
+    v_scale,
     *,
     sm_scale=1.0,
     sliding_window: int | None = None,
@@ -927,6 +935,12 @@
                                            head_dim)[:kv_len]
     v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads,
                                               head_dim)[:kv_len]
+    if k_scale is not None:
+      k = k.to(torch.float32) * k_scale
+      k = k.to(q.dtype)
+    if v_scale is not None:
+      v = v.to(torch.float32) * v_scale
+      v = v.to(q.dtype)
     k = torch.repeat_interleave(k, num_query_per_kv, dim=1)
     v = torch.repeat_interleave(v, num_query_per_kv, dim=1)
     attn = torch.einsum("qhd,khd->hqk", q, k)
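
The reference (non-kernel) path above dequantizes in float32 and casts back to the query dtype. A standalone sketch of that round trip; the quantize step (dividing by the scale) is an assumed counterpart added for illustration, since the commit only applies the scale when reading the cache, and 0.5 simply mirrors the k_scale/v_scale hard-coded in the new tests:

```python
import torch

q_dtype = torch.bfloat16
k_scale = 0.5  # per-tensor dequantization scale, as in the new tests

k_ref = torch.randn(16, 8, 128)                  # [kv_len, num_kv_heads, head_dim]
k_fp8 = (k_ref / k_scale).to(torch.float8_e5m2)  # assumed quantize step
k_deq = (k_fp8.to(torch.float32) * k_scale).to(q_dtype)  # matches the diff above

print((k_deq.float() - k_ref).abs().max())  # small, fp8-e5m2-level error
```
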
@@ -963,6 +977,8 @@ def ragged_paged_attention(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
     mask_value=None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
     use_kernel=True,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
@@ -984,6 +1000,8 @@
         sliding_window=sliding_window,
         soft_cap=soft_cap,
         mask_value=mask_value,
+        k_scale=k_scale,
+        v_scale=v_scale,
     )
 
   # Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -1005,6 +1023,8 @@
       sliding_window=sliding_window,
       soft_cap=soft_cap,
       mask_value=mask_value,
+      k_scale=k_scale,
+      v_scale=v_scale,
       num_kv_pages_per_block=num_kv_pages_per_block,
       num_queries_per_block=num_queries_per_block,
       vmem_limit_bytes=vmem_limit_bytes,
@@ -1016,6 +1036,8 @@
           "num_kv_pages_per_block",
           "num_queries_per_block",
           "vmem_limit_bytes",
+          "k_scale",
+          "v_scale",
       ],
   )
 
@@ -1492,7 +1514,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
 XLA_LIB.define(
     "ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
     "Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
-    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True,"
+    "float? soft_cap=None, float? mask_value=None, float? k_scale=None, float? v_scale=None, bool use_kernel=True,"
     "int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
 )
 
@@ -1509,6 +1531,8 @@ def ragged_paged_attention_xla(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
     mask_value=None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
     use_kernel=True,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
@@ -1526,6 +1550,8 @@
       sliding_window=sliding_window,
       soft_cap=soft_cap,
       mask_value=mask_value,
+      k_scale=k_scale,
+      v_scale=v_scale,
       use_kernel=use_kernel,
       num_kv_pages_per_block=num_kv_pages_per_block,
       num_queries_per_block=num_queries_per_block,
@@ -1544,6 +1570,8 @@ def ragged_paged_attention_non_xla(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
     mask_value=None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
     use_kernel=True,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
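
For context, a hedged end-to-end sketch of the updated public API with a bf16 query and an fp8-e5m2 paged KV cache. Tensor shapes follow the in-code comments; the concrete sizes, the all-zero page table, and the assumption that the use_kernel=False reference path accepts plain CPU tensors are illustrative choices, not part of the commit:

```python
import torch
from torch_xla.experimental.custom_kernel import ragged_paged_attention

max_num_seqs, pages_per_seq = 16, 8
num_pages, page_size, head_dim = 1000, 16, 128
num_q_heads, num_kv_heads = 32, 8
total_tokens = 128

q = torch.randn(total_tokens, num_q_heads, head_dim, dtype=torch.bfloat16)
# KV cache stored in fp8; generated via float32 as in the tests.
kv_pages = torch.randn(
    num_pages, page_size, num_kv_heads * 2, head_dim,
    dtype=torch.float32).to(torch.float8_e5m2)

# One active sequence owning all 128 query tokens against a 128-token history.
kv_lens = torch.zeros(max_num_seqs, dtype=torch.int32)
kv_lens[0] = 128
page_indices = torch.zeros(max_num_seqs, pages_per_seq, dtype=torch.int32)
cu_q_lens = torch.zeros(max_num_seqs + 1, dtype=torch.int32)
cu_q_lens[1:] = total_tokens
num_seqs = torch.tensor([1], dtype=torch.int32)

out = ragged_paged_attention(
    q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=head_dim**-0.5,
    k_scale=0.5,       # dequantization scales for the fp8 KV cache
    v_scale=0.5,
    use_kernel=False,  # pure-PyTorch reference path; use_kernel=True targets TPU
)
print(out.shape)  # [total_tokens, num_q_heads, head_dim]
```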
