Commit 9834f06

added a jax ref impl which will be used in google3
1 parent 92672e4

1 file changed: test/test_pallas.py (+155 −19 lines)
@@ -800,44 +800,180 @@ def test_extended_paged_attention_v1_multiple_queries(self):
     print('my new extended paged attention finished yay')

     # Run Woosuk's non-kernel impl.
-    ref_q_torch = q.detach().clone()
-    assert ref_q_torch.shape==(batch_size, query_len, num_query_heads, head_size), f"Input ref_q_torch has the wrong shape: {ref_q_torch.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
-    assert jnp.allclose(q_jax, jnp.array(ref_q_torch.numpy(), dtype=jnp.float32))
-    ref_k_pages_torch = k_pages.detach().clone()
-    assert jnp.allclose(k_pages_jax, jnp.array(ref_k_pages_torch.numpy(), dtype=jnp.float32))
-    ref_v_pages_torch = v_pages.detach().clone()
-    assert jnp.allclose(v_pages_jax, jnp.array(ref_v_pages_torch.numpy(), dtype=jnp.float32))
-    ref_kv_seq_lens_torch = kv_seq_lengths.detach().clone()
-    assert jnp.allclose(kv_seq_lens_jax, jnp.array(ref_kv_seq_lens_torch.numpy(), dtype=jnp.int32))
-    ref_page_indices_torch = page_indices.detach().clone()
-    assert jnp.allclose(page_indices_jax, jnp.array(ref_page_indices_torch.numpy(), dtype=jnp.int32))
+    assert q.shape==(batch_size, query_len, num_query_heads, head_size), f"Input q has the wrong shape: {q.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax, jnp.array(q.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(k_pages_jax, jnp.array(k_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(v_pages_jax, jnp.array(v_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(kv_seq_lens_jax, jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32))
+    assert jnp.allclose(page_indices_jax, jnp.array(page_indices.numpy(), dtype=jnp.int32))

     expected_output = ref_extended_paged_attention(
-        ref_q_torch,
-        ref_k_pages_torch,
-        ref_v_pages_torch,
-        ref_kv_seq_lens_torch,
-        ref_page_indices_torch,
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lengths,
+        page_indices,
     )

     expected_output_cpu=expected_output.cpu()
     # Need to squeeze out the query_len dimension!
     actual_output_cpu=actual_output.cpu()
+    print(f'{expected_output_cpu=}')
+    print(f'{actual_output_cpu=}')
+    print(f'actual_output_cpu.shape={actual_output_cpu.shape}')
+    print(f'expected_output_cpu.shape={expected_output_cpu.shape}')
+    self.assertEqual(actual_output_cpu.shape, expected_output_cpu.shape)
+    # torch.set_printoptions(profile="full")
+    print(f'{(actual_output_cpu-expected_output_cpu).abs()}')
+    print(f'Output max diff: {(expected_output_cpu - actual_output_cpu).abs().max().item()}')
+    print(f'Output mean diff: {(expected_output_cpu - actual_output_cpu).abs().mean().item()}')
+    self.assertTrue(
+        torch.allclose(
+            expected_output_cpu,
+            actual_output_cpu,
+            atol=1e-2,
+            rtol=1e-2))
+
+  def _ref_jax_extended_paged_attention(
+      self,
+      q,  # [batch_size, query_len, num_query_heads, head_size]
+      k_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
+      v_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
+      lengths,  # [batch_size]
+      page_indices,  # [batch_size, pages_per_sequence]
+  ):
+    batch_size, query_len, num_query_heads, head_size = q.shape
+    num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
+    num_query_per_kv = num_query_heads // num_kv_heads
+
+    lengths = lengths
+    page_indices = page_indices
+    outputs = []
+    for i in range(batch_size):
+      kv_len = lengths[i]
+      num_pages = (kv_len + page_size - 1) // page_size
+      indices = page_indices[i, :num_pages]
+
+      k = k_pages[:, indices]
+      k = jnp.permute_dims(k, (1, 2, 0, 3))
+      k = jnp.reshape(k, (num_pages * page_size, num_kv_heads, head_size))
+      k = k[:kv_len]
+
+      v = v_pages[:, indices]
+      v = jnp.permute_dims(v, (1, 2, 0, 3))
+      v = jnp.reshape(v, (num_pages * page_size, num_kv_heads, head_size))
+      v = v[:kv_len]
+
+      if num_query_per_kv != 1:
+        k = jnp.repeat(k, num_query_per_kv, axis=1)
+        v = jnp.repeat(v, num_query_per_kv, axis=1)
+
+      attn = jnp.einsum("qhd,khd->hqk", q[i], k)
+      attn = attn.astype('float32')
+      q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
+          jnp.int32, (query_len, kv_len), 0
+      )
+      kv_span = jax.lax.broadcasted_iota(
+          jnp.int32, (query_len, kv_len), 1
+      )
+      mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
+      with jax.numpy_rank_promotion("allow"):
+        attn = attn + mask
+      attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+      out = jnp.einsum("hqk,khd->qhd", attn, v)
+      outputs.append(out)
+    output = jnp.stack(outputs, axis=0)
+    return output
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_extended_paged_attention_jax_ref_impl(self):
+    # Use multiple queries, verify my jax ref impl of paged attention against
+    # Woosuk's pytorch ref impl. Should get the same result.
+    from torch_xla.experimental.custom_kernel import ref_extended_paged_attention
+
+    batch_size: int = 3
+    head_size: int = 128
+    dtype_torch: torch.dtype = torch.float32
+    max_kv_len: int = 1024
+    total_num_pages: int = 32
+    pages_per_sequence = total_num_pages
+
+    # num_compute_blks_q=1, num_compute_blks_kv=1, num_q_heads_per_kv_head=8
+    # num_compute_blks_q=(query_len//num_queries_per_compute_block)
+    # Change num_queries_per_compute_block to adjust num_compute_blks_q
+    # num_compute_blks_kv=(pages_per_sequence//num_kv_pages_per_compute_block) where
+    # pages_per_sequence is the same as total_num_pages
+    # Change pallas_compute_block_size to adjust num_compute_blks_kv
+    pallas_compute_block_size = 2048
+    page_size: int = 64
+    num_kv_pages_per_compute_block = pallas_compute_block_size // page_size
+    query_len: int = 8
+    num_queries_per_compute_block = 8
+    num_query_heads: int = 64
+    num_kv_heads: int = 8
+
+    assert num_query_heads % num_kv_heads == 0
+    assert query_len <= max_kv_len
+    assert max_kv_len <= total_num_pages * page_size
+
+    print(f'The test test_extended_paged_attention_jax_ref_impl begins with {query_len=}')
+    q = torch.randn(batch_size, query_len, num_query_heads, head_size, dtype=dtype_torch)
+    k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_size, dtype=dtype_torch)
+    v_pages = torch.rand_like(k_pages)
+    kv_seq_lengths = torch.randint(query_len, max_kv_len + 1, (batch_size,))
+    page_indices = torch.randint(0, total_num_pages, (batch_size, total_num_pages))
+
+    # Run the jax ref impl with query_len>1
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    assert q_jax.shape==(batch_size, query_len, num_query_heads, head_size), f"Input q_jax has the wrong shape: {q_jax.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    print('xw32 calling jax_extended_paged_attention1')
+    actual_output = self._ref_jax_extended_paged_attention(
+        q_jax,
+        k_pages_jax,
+        v_pages_jax,
+        kv_seq_lens_jax,
+        page_indices_jax,
+    )
+
+
+    # Run Woosuk's non-kernel impl.
+    assert q.shape==(batch_size, query_len, num_query_heads, head_size), f"Input q has the wrong shape: {q.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax, jnp.array(q.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(k_pages_jax, jnp.array(k_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(v_pages_jax, jnp.array(v_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(kv_seq_lens_jax, jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32))
+    assert jnp.allclose(page_indices_jax, jnp.array(page_indices.numpy(), dtype=jnp.int32))
+
+    expected_output = ref_extended_paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lengths,
+        page_indices,
+    )
+
+    expected_output_cpu = expected_output.cpu()
+    actual_output_cpu = torch.from_numpy(np.array(actual_output))
     # print(f'{expected_output_cpu=}')
     # print(f'{actual_output_cpu=}')
     print(f'actual_output_cpu.shape={actual_output_cpu.shape}')
     print(f'expected_output_cpu.shape={expected_output_cpu.shape}')
     self.assertEqual(actual_output_cpu.shape, expected_output_cpu.shape)
     # torch.set_printoptions(profile="full")
-    print(f'{(actual_output_cpu-expected_output_cpu).abs()}')
+    # print(f'{(actual_output_cpu-expected_output_cpu).abs()}')
     print(f'Output max diff: {(expected_output_cpu - actual_output_cpu).abs().max().item()}')
     print(f'Output mean diff: {(expected_output_cpu - actual_output_cpu).abs().mean().item()}')
     self.assertTrue(
         torch.allclose(
             expected_output_cpu,
             actual_output_cpu,
-            atol=1e-5,
-            rtol=1e-5))
+            atol=3e-2,
+            rtol=2e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
