@@ -800,31 +800,26 @@ def test_extended_paged_attention_v1_multiple_queries(self):
     print('my new extended paged attention finished yay')

     # Run Woosuk's non-kernel impl.
-    ref_q_torch = q.detach().clone()
-    assert ref_q_torch.shape == (batch_size, query_len, num_query_heads, head_size), f"Input ref_q_torch has the wrong shape: {ref_q_torch.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
-    assert jnp.allclose(q_jax, jnp.array(ref_q_torch.numpy(), dtype=jnp.float32))
-    ref_k_pages_torch = k_pages.detach().clone()
-    assert jnp.allclose(k_pages_jax, jnp.array(ref_k_pages_torch.numpy(), dtype=jnp.float32))
-    ref_v_pages_torch = v_pages.detach().clone()
-    assert jnp.allclose(v_pages_jax, jnp.array(ref_v_pages_torch.numpy(), dtype=jnp.float32))
-    ref_kv_seq_lens_torch = kv_seq_lengths.detach().clone()
-    assert jnp.allclose(kv_seq_lens_jax, jnp.array(ref_kv_seq_lens_torch.numpy(), dtype=jnp.int32))
-    ref_page_indices_torch = page_indices.detach().clone()
-    assert jnp.allclose(page_indices_jax, jnp.array(ref_page_indices_torch.numpy(), dtype=jnp.int32))
+    assert q.shape == (batch_size, query_len, num_query_heads, head_size), f"Input q has the wrong shape: {q.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax, jnp.array(q.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(k_pages_jax, jnp.array(k_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(v_pages_jax, jnp.array(v_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(kv_seq_lens_jax, jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32))
+    assert jnp.allclose(page_indices_jax, jnp.array(page_indices.numpy(), dtype=jnp.int32))

     expected_output = ref_extended_paged_attention(
-        ref_q_torch,
-        ref_k_pages_torch,
-        ref_v_pages_torch,
-        ref_kv_seq_lens_torch,
-        ref_page_indices_torch,
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lengths,
+        page_indices,
     )

     expected_output_cpu = expected_output.cpu()
     # Need to squeeze out the query_len dimension!
     actual_output_cpu = actual_output.cpu()
-    # print(f'{expected_output_cpu=}')
-    # print(f'{actual_output_cpu=}')
+    print(f'{expected_output_cpu=}')
+    print(f'{actual_output_cpu=}')
     print(f'actual_output_cpu.shape={actual_output_cpu.shape}')
     print(f'expected_output_cpu.shape={expected_output_cpu.shape}')
     self.assertEqual(actual_output_cpu.shape, expected_output_cpu.shape)
@@ -836,8 +831,148 @@ def test_extended_paged_attention_v1_multiple_queries(self):
         torch.allclose(
             expected_output_cpu,
             actual_output_cpu,
-            atol=1e-5,
-            rtol=1e-5))
+            atol=1e-2,
+            rtol=1e-2))
+
+  def ref_jax_extended_paged_attention(
+      self,
+      q,  # [batch_size, query_len, num_query_heads, head_size]
+      k_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
+      v_pages,  # [num_kv_heads, total_num_pages, page_size, head_size]
+      lengths,  # [batch_size]
+      page_indices,  # [batch_size, pages_per_sequence]
+  ):
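+    # Non-kernel JAX reference implementation of extended (multi-query) paged
+    # attention; it mirrors Woosuk's PyTorch reference
+    # ref_extended_paged_attention and is compared against it below.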
+    batch_size, query_len, num_query_heads, head_size = q.shape
+    num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
+    num_query_per_kv = num_query_heads // num_kv_heads
+
+    lengths = lengths
+    page_indices = page_indices
+    outputs = []
+    for i in range(batch_size):
+      kv_len = lengths[i]
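+      # Ceiling division: the number of pages needed to hold kv_len tokens.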
+      num_pages = (kv_len + page_size - 1) // page_size
+      indices = page_indices[i, :num_pages]
+
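+      # Gather this sequence's KV pages and flatten them into
+      # [kv_len, num_kv_heads, head_size].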
+      k = k_pages[:, indices]
+      k = jnp.permute_dims(k, (1, 2, 0, 3))
+      k = jnp.reshape(k, (num_pages * page_size, num_kv_heads, head_size))
+      k = k[:kv_len]
+
+      v = v_pages[:, indices]
+      v = jnp.permute_dims(v, (1, 2, 0, 3))
+      v = jnp.reshape(v, (num_pages * page_size, num_kv_heads, head_size))
+      v = v[:kv_len]
+
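+      # Grouped-query attention: repeat each KV head so it lines up with its
+      # group of query heads.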
+      if num_query_per_kv != 1:
+        k = jnp.repeat(k, num_query_per_kv, axis=1)
+        v = jnp.repeat(v, num_query_per_kv, axis=1)
+
+      attn = jnp.einsum("qhd,khd->hqk", q[i], k)
+      attn = attn.astype('float32')
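+      # Causal mask: query row j (one of the last query_len tokens of the
+      # sequence) may only attend to kv positions <= kv_len - query_len + j.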
+      q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
+          jnp.int32, (query_len, kv_len), 0
+      )
+      kv_span = jax.lax.broadcasted_iota(
+          jnp.int32, (query_len, kv_len), 1
+      )
+      mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
+      attn = attn + mask
+      attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
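+      # Weighted sum over the values gives this sequence's output of shape
+      # [query_len, num_query_heads, head_size].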
+      out = jnp.einsum("hqk,khd->qhd", attn, v)
+      outputs.append(out)
+    output = jnp.stack(outputs, axis=0)
+    return output
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_extended_paged_attention_jax_ref_impl(self):
+    # Use multiple queries and verify the JAX reference impl of paged
+    # attention against Woosuk's PyTorch reference impl. They should produce
+    # the same result.
+    from torch_xla.experimental.custom_kernel import ref_extended_paged_attention
+
+    batch_size: int = 3
+    head_size: int = 128
+    dtype_torch: torch.dtype = torch.float32
+    max_kv_len: int = 1024
+    total_num_pages: int = 32
+    pages_per_sequence = total_num_pages
+
+    # num_compute_blks_q=1, num_compute_blks_kv=1, num_q_heads_per_kv_head=8.
+    # num_compute_blks_q = query_len // num_queries_per_compute_block;
+    # change num_queries_per_compute_block to adjust num_compute_blks_q.
+    # num_compute_blks_kv = pages_per_sequence // num_kv_pages_per_compute_block,
+    # where pages_per_sequence is the same as total_num_pages;
+    # change pallas_compute_block_size to adjust num_compute_blks_kv.
+    pallas_compute_block_size = 2048
+    page_size: int = 64
+    num_kv_pages_per_compute_block = pallas_compute_block_size // page_size
+    query_len: int = 8
+    num_queries_per_compute_block = 8
+    num_query_heads: int = 64
+    num_kv_heads: int = 8
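+    # With these values: num_kv_pages_per_compute_block = 2048 // 64 = 32, so
+    # num_compute_blks_kv = 32 // 32 = 1; num_compute_blks_q = 8 // 8 = 1; and
+    # num_q_heads_per_kv_head = 64 // 8 = 8.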
+
+    assert num_query_heads % num_kv_heads == 0
+    assert query_len <= max_kv_len
+    assert max_kv_len <= total_num_pages * page_size
+
+    print(f'The test test_extended_paged_attention_jax_ref_impl begins with {query_len=}')
+    q = torch.randn(batch_size, query_len, num_query_heads, head_size, dtype=dtype_torch)
+    k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_size, dtype=dtype_torch)
+    v_pages = torch.rand_like(k_pages)
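+    # Each sequence gets a random KV length in [query_len, max_kv_len] and a
+    # random page table over the total_num_pages physical pages.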
+    kv_seq_lengths = torch.randint(query_len, max_kv_len + 1, (batch_size,))
+    page_indices = torch.randint(0, total_num_pages, (batch_size, total_num_pages))
+
+    # Run the JAX ref impl with query_len > 1.
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    assert q_jax.shape == (batch_size, query_len, num_query_heads, head_size), f"Input q_jax has the wrong shape: {q_jax.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    print('xw32 calling ref_jax_extended_paged_attention')
+    actual_output = self.ref_jax_extended_paged_attention(
+        q_jax,
+        k_pages_jax,
+        v_pages_jax,
+        kv_seq_lens_jax,
+        page_indices_jax,
+    )
+
+    # Run Woosuk's non-kernel impl.
+    assert q.shape == (batch_size, query_len, num_query_heads, head_size), f"Input q has the wrong shape: {q.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax, jnp.array(q.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(k_pages_jax, jnp.array(k_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(v_pages_jax, jnp.array(v_pages.numpy(), dtype=jnp.float32))
+    assert jnp.allclose(kv_seq_lens_jax, jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32))
+    assert jnp.allclose(page_indices_jax, jnp.array(page_indices.numpy(), dtype=jnp.int32))
+
+    expected_output = ref_extended_paged_attention(
+        q,
+        k_pages,
+        v_pages,
+        kv_seq_lengths,
+        page_indices,
+    )
+
+    expected_output_cpu = expected_output.cpu()
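+    # The JAX output is materialized as a numpy array and wrapped in a torch
+    # tensor so it can be compared against the PyTorch reference output.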
+    actual_output_cpu = torch.from_numpy(np.array(actual_output))
+    print(f'{expected_output_cpu=}')
+    print(f'{actual_output_cpu=}')
+    print(f'actual_output_cpu.shape={actual_output_cpu.shape}')
+    print(f'expected_output_cpu.shape={expected_output_cpu.shape}')
+    self.assertEqual(actual_output_cpu.shape, expected_output_cpu.shape)
+    # torch.set_printoptions(profile="full")
+    print(f'{(actual_output_cpu - expected_output_cpu).abs()}')
+    print(f'Output max diff: {(expected_output_cpu - actual_output_cpu).abs().max().item()}')
+    print(f'Output mean diff: {(expected_output_cpu - actual_output_cpu).abs().mean().item()}')
+    self.assertTrue(
+        torch.allclose(
+            expected_output_cpu,
+            actual_output_cpu,
+            atol=3e-2,
+            rtol=2e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")