
Commit e553547

added a jax ref impl which will be used in google3

1 parent 92672e4

1 file changed

test/test_pallas.py (+155 -20)
@@ -800,31 +800,26 @@ def test_extended_paged_attention_v1_multiple_queries(self):
     print('my new extended paged attention finished yay')
 
     # Run Woosuk's non-kernel impl.
-    ref_q_torch = q.detach().clone()
-    assert ref_q_torch.shape==(batch_size, query_len, num_query_heads, head_size), f"Input ref_q_torch has the wrong shape: {ref_q_torch.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
-    assert jnp.allclose(q_jax, jnp.array(ref_q_torch.numpy(), dtype=jnp.float32))
-    ref_k_pages_torch = k_pages.detach().clone()
-    assert jnp.allclose(k_pages_jax, jnp.array(ref_k_pages_torch.numpy(), dtype=jnp.float32))
-    ref_v_pages_torch = v_pages.detach().clone()
-    assert jnp.allclose(v_pages_jax, jnp.array(ref_v_pages_torch.numpy(), dtype=jnp.float32))
-    ref_kv_seq_lens_torch = kv_seq_lengths.detach().clone()
-    assert jnp.allclose(kv_seq_lens_jax, jnp.array(ref_kv_seq_lens_torch.numpy(), dtype=jnp.int32))
-    ref_page_indices_torch = page_indices.detach().clone()
-    assert jnp.allclose(page_indices_jax, jnp.array(ref_page_indices_torch.numpy(), dtype=jnp.int32))
+    assert q.shape==(batch_size, query_len, num_query_heads, head_size), f"Input q has the wrong shape: {q.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax, jnp.array(q.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(k_pages_jax, jnp.array(k_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(v_pages_jax, jnp.array(v_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(kv_seq_lens_jax, jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32))
+    assert jnp.allclose(page_indices_jax, jnp.array(page_indices.numpy(), dtype=jnp.int32))
 
     expected_output = ref_extended_paged_attention(
-        ref_q_torch,
-        ref_k_pages_torch,
-        ref_v_pages_torch,
-        ref_kv_seq_lens_torch,
-        ref_page_indices_torch,
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lengths,
+        page_indices,
     )
 
     expected_output_cpu=expected_output.cpu()
     # Need to squeeze out the query_len dimension!
     actual_output_cpu=actual_output.cpu()
-    # print(f'{expected_output_cpu=}')
-    # print(f'{actual_output_cpu=}')
+    print(f'{expected_output_cpu=}')
+    print(f'{actual_output_cpu=}')
     print(f'actual_output_cpu.shape={actual_output_cpu.shape}')
     print(f'expected_output_cpu.shape={expected_output_cpu.shape}')
     self.assertEqual(actual_output_cpu.shape, expected_output_cpu.shape)
@@ -836,8 +831,148 @@ def test_extended_paged_attention_v1_multiple_queries(self):
         torch.allclose(
             expected_output_cpu,
             actual_output_cpu,
-            atol=1e-5,
-            rtol=1e-5))
+            atol=1e-2,
+            rtol=1e-2))
+
+  def ref_jax_extended_paged_attention(
+      self,
+      q,  # [batch_size, query_len, num_query_heads, head_size]
+      k_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
+      v_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
+      lengths,  # [batch_size]
+      page_indices,  # [batch_size, pages_per_sequence]
+  ):
+    batch_size, query_len, num_query_heads, head_size = q.shape
+    num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
+    num_query_per_kv = num_query_heads // num_kv_heads
+
+    lengths = lengths
+    page_indices = page_indices
+    outputs = []
+    for i in range(batch_size):
+      kv_len = lengths[i]
+      num_pages = (kv_len+page_size-1) // page_size
+      indices = page_indices[i, :num_pages]
+
+      k = k_pages[:, indices]
+      k = jnp.permute_dims(k, (1, 2, 0, 3))
+      k = jnp.reshape(k, (num_pages * page_size, num_kv_heads, head_size))
+      k = k[:kv_len]
+
+      v = v_pages[:, indices]
+      v = jnp.permute_dims(v, (1, 2, 0, 3))
+      v = jnp.reshape(v, (num_pages * page_size, num_kv_heads, head_size))
+      v = v[:kv_len]
+
+      if num_query_per_kv != 1:
+        k = jnp.repeat(k, num_query_per_kv, axis=1)
+        v = jnp.repeat(v, num_query_per_kv, axis=1)
+
+      attn = jnp.einsum("qhd,khd->hqk", q[i], k)
+      attn = attn.astype('float32')
+      q_span = (kv_len-query_len) + jax.lax.broadcasted_iota(
+          jnp.int32, (query_len, kv_len), 0
+      )
+      kv_span = jax.lax.broadcasted_iota(
+          jnp.int32, (query_len, kv_len), 1
+      )
+      mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
+      attn = attn + mask
+      attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+      out = jnp.einsum("hqk,khd->qhd", attn, v)
+      outputs.append(out)
+    output = jnp.stack(outputs, axis=0)
+    return output
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_extended_paged_attention_jax_ref_impl(self):
+    # Use multiple queries to verify my jax ref impl of paged attention against
+    # Woosuk's pytorch ref impl. Both should produce the same result.
+    from torch_xla.experimental.custom_kernel import ref_extended_paged_attention
+
+    batch_size: int = 3
+    head_size: int = 128
+    dtype_torch: torch.dtype = torch.float32
+    max_kv_len: int = 1024
+    total_num_pages: int = 32
+    pages_per_sequence = total_num_pages
+
+    # num_compute_blks_q=1, num_compute_blks_kv=1, num_q_heads_per_kv_head=8.
+    # num_compute_blks_q = query_len // num_queries_per_compute_block;
+    # change num_queries_per_compute_block to adjust num_compute_blks_q.
+    # num_compute_blks_kv = pages_per_sequence // num_kv_pages_per_compute_block,
+    # where pages_per_sequence is the same as total_num_pages;
+    # change pallas_compute_block_size to adjust num_compute_blks_kv.
+    pallas_compute_block_size = 2048
+    page_size: int = 64
+    num_kv_pages_per_compute_block = pallas_compute_block_size // page_size
+    query_len: int = 8
+    num_queries_per_compute_block = 8
+    num_query_heads: int = 64
+    num_kv_heads: int = 8
+
+    assert num_query_heads % num_kv_heads == 0
+    assert query_len <= max_kv_len
+    assert max_kv_len <= total_num_pages * page_size
+
+    print(f'The test test_extended_paged_attention_jax_ref_impl begins with {query_len=}')
+    q = torch.randn(batch_size, query_len, num_query_heads, head_size, dtype=dtype_torch)
+    k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_size, dtype=dtype_torch)
+    v_pages = torch.rand_like(k_pages)
+    kv_seq_lengths = torch.randint(query_len, max_kv_len + 1, (batch_size,))
+    page_indices = torch.randint(0, total_num_pages, (batch_size, total_num_pages))
+
+    # Run the jax ref impl with query_len>1
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    assert q_jax.shape==(batch_size, query_len, num_query_heads, head_size), f"Input q_jax has the wrong shape: {q_jax.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    print('calling ref_jax_extended_paged_attention')
+    actual_output = self.ref_jax_extended_paged_attention(
+        q_jax,
+        k_pages_jax,
+        v_pages_jax,
+        kv_seq_lens_jax,
+        page_indices_jax,
+    )
+
+
+    # Run Woosuk's non-kernel impl.
+    assert q.shape==(batch_size, query_len, num_query_heads, head_size), f"Input q has the wrong shape: {q.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax, jnp.array(q.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(k_pages_jax, jnp.array(k_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(v_pages_jax, jnp.array(v_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(kv_seq_lens_jax, jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32))
+    assert jnp.allclose(page_indices_jax, jnp.array(page_indices.numpy(), dtype=jnp.int32))
+
+    expected_output = ref_extended_paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lengths,
+        page_indices,
+    )
+
+    expected_output_cpu=expected_output.cpu()
+    actual_output_cpu=torch.from_numpy(np.array(actual_output))
+    print(f'{expected_output_cpu=}')
+    print(f'{actual_output_cpu=}')
+    print(f'actual_output_cpu.shape={actual_output_cpu.shape}')
+    print(f'expected_output_cpu.shape={expected_output_cpu.shape}')
+    self.assertEqual(actual_output_cpu.shape, expected_output_cpu.shape)
+    # torch.set_printoptions(profile="full")
+    print(f'{(actual_output_cpu-expected_output_cpu).abs()}')
+    print(f'Output max diff: {(expected_output_cpu - actual_output_cpu).abs().max().item()}')
+    print(f'Output mean diff: {(expected_output_cpu - actual_output_cpu).abs().mean().item()}')
+    self.assertTrue(
+        torch.allclose(
+            expected_output_cpu,
+            actual_output_cpu,
+            atol=3e-2,
+            rtol=2e-2))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")

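For readers checking the block-size comment in test_extended_paged_attention_jax_ref_impl above, the grid arithmetic can be verified with the test's own parameter values. This is a standalone sketch; the names num_compute_blks_q, num_compute_blks_kv, and num_q_heads_per_kv_head simply mirror the comment and are not kernel API.

```python
# Grid-size arithmetic from the comment in the new test, using its parameters.
# Variable names mirror the comment; they are not part of any kernel API.
query_len = 8
num_queries_per_compute_block = 8
pallas_compute_block_size = 2048
page_size = 64
total_num_pages = 32
pages_per_sequence = total_num_pages          # as stated in the comment
num_query_heads, num_kv_heads = 64, 8

num_kv_pages_per_compute_block = pallas_compute_block_size // page_size      # 2048 // 64 = 32
num_compute_blks_q = query_len // num_queries_per_compute_block              # 8 // 8 = 1
num_compute_blks_kv = pages_per_sequence // num_kv_pages_per_compute_block   # 32 // 32 = 1
num_q_heads_per_kv_head = num_query_heads // num_kv_heads                    # 64 // 8 = 8

assert (num_compute_blks_q, num_compute_blks_kv, num_q_heads_per_kv_head) == (1, 1, 8)
```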
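The least obvious part of the new ref_jax_extended_paged_attention helper is the causal mask built from two broadcasted_iota calls: q_span holds each query's absolute position in the KV sequence (the query block is taken to be the last query_len tokens), and any entry where q_span < kv_span is masked out. A tiny standalone illustration, not part of the commit:

```python
# Minimal illustration of the causal-mask construction used in the new helper,
# shown on a toy case with query_len=2 and kv_len=4. Requires jax.
import jax
import jax.numpy as jnp

query_len, kv_len = 2, 4
# Row i holds the absolute position of query i, assuming the queries are the
# last `query_len` tokens of the KV sequence.
q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
    jnp.int32, (query_len, kv_len), 0)
# Column j holds the key position j.
kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
print(mask)
# [[  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]
```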
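The per-sequence KV gather in the helper can also be sanity-checked in isolation: the pages listed in the page table are gathered, rearranged into a [tokens, kv_heads, head_size] layout, and the padded tail of the last page is trimmed with [:kv_len]. A small sketch under made-up toy shapes (jnp.permute_dims requires a recent JAX; jnp.transpose with the same axes is equivalent):

```python
# Toy sketch of the paged-KV flattening done per sequence inside the helper.
# Shapes and page indices are invented for illustration; not part of the commit.
import jax.numpy as jnp

num_kv_heads, total_num_pages, page_size, head_size = 2, 6, 4, 3
k_pages = jnp.arange(
    num_kv_heads * total_num_pages * page_size * head_size,
    dtype=jnp.float32).reshape(num_kv_heads, total_num_pages, page_size, head_size)

kv_len = 7                                   # this sequence spans ceil(7 / 4) = 2 pages
num_pages = (kv_len + page_size - 1) // page_size
indices = jnp.array([5, 0])                  # page-table entries for this sequence

k = k_pages[:, indices]                      # [num_kv_heads, num_pages, page_size, head_size]
k = jnp.permute_dims(k, (1, 2, 0, 3))        # [num_pages, page_size, num_kv_heads, head_size]
k = jnp.reshape(k, (num_pages * page_size, num_kv_heads, head_size))
k = k[:kv_len]                               # drop the padding in the last page
assert k.shape == (kv_len, num_kv_heads, head_size)
```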
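Finally, the new test moves data between the two frameworks in a specific way: torch tensors enter JAX through .numpy(), and the JAX result returns to torch through np.array() so that torch.allclose can compare the two reference implementations. A toy round-trip, standalone and not part of the commit:

```python
# Round-trip pattern used by the new test for comparing a JAX result against a
# PyTorch reference. The doubling is a stand-in for the JAX ref impl.
import numpy as np
import torch
import jax.numpy as jnp

t = torch.randn(2, 3, dtype=torch.float32)
t_jax = jnp.array(t.numpy(), dtype=jnp.float32)   # torch -> JAX (copies the data)
out_jax = t_jax * 2.0                             # stand-in for the JAX computation
out_torch = torch.from_numpy(np.array(out_jax))   # JAX -> torch for torch.allclose
assert torch.allclose(out_torch, t * 2.0, atol=1e-6, rtol=1e-6)
```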