@@ -800,44 +800,180 @@ def test_extended_paged_attention_v1_multiple_queries(self):
800
800
print ('my new extended paged attention finished yay' )
801
801
802
802
# Run Woosuk's non-kernel impl.
803
- ref_q_torch = q .detach ().clone ()
804
- assert ref_q_torch .shape == (batch_size , query_len , num_query_heads , head_size ), f"Input ref_q_torch has the wrong shape: { ref_q_torch .shape } . Expect { (batch_size , query_len , num_query_heads , head_size )} ."
805
- assert jnp .allclose (q_jax , jnp .array (ref_q_torch .numpy (), dtype = jnp .float32 ))
806
- ref_k_pages_torch = k_pages .detach ().clone ()
807
- assert jnp .allclose (k_pages_jax , jnp .array (ref_k_pages_torch .numpy (), dtype = jnp .float32 ))
808
- ref_v_pages_torch = v_pages .detach ().clone ()
809
- assert jnp .allclose (v_pages_jax , jnp .array (ref_v_pages_torch .numpy (), dtype = jnp .float32 ))
810
- ref_kv_seq_lens_torch = kv_seq_lengths .detach ().clone ()
811
- assert jnp .allclose (kv_seq_lens_jax , jnp .array (ref_kv_seq_lens_torch .numpy (), dtype = jnp .int32 ))
812
- ref_page_indices_torch = page_indices .detach ().clone ()
813
- assert jnp .allclose (page_indices_jax , jnp .array (ref_page_indices_torch .numpy (), dtype = jnp .int32 ))
803
+ assert q .shape == (batch_size , query_len , num_query_heads , head_size ), f"Input ref_q_torch has the wrong shape: { q .shape } . Expect { (batch_size , query_len , num_query_heads , head_size )} ."
804
+ assert jnp .allclose (q_jax , jnp .array (q .numpy (), dtype = jnp .float32 ))
805
+ assert jnp .allclose (k_pages_jax , jnp .array (k_pages .numpy (), dtype = jnp .float32 ))
806
+ assert jnp .allclose (v_pages_jax , jnp .array (v_pages .numpy (), dtype = jnp .float32 ))
807
+ assert jnp .allclose (kv_seq_lens_jax , jnp .array (kv_seq_lengths .numpy (), dtype = jnp .int32 ))
808
+ assert jnp .allclose (page_indices_jax , jnp .array (page_indices .numpy (), dtype = jnp .int32 ))
814
809
815
810
expected_output = ref_extended_paged_attention (
816
- ref_q_torch ,
817
- ref_k_pages_torch ,
818
- ref_v_pages_torch ,
819
- ref_kv_seq_lens_torch ,
820
- ref_page_indices_torch ,
811
+ q ,
812
+ k_pages ,
813
+ v_pages ,
814
+ kv_seq_lengths ,
815
+ page_indices ,
821
816
)
822
817
823
818
expected_output_cpu = expected_output .cpu ()
824
819
# Need to squeeze out the query_len dimension!
825
820
actual_output_cpu = actual_output .cpu ()
821
+ print (f'{ expected_output_cpu = } ' )
822
+ print (f'{ actual_output_cpu = } ' )
823
+ print (f'actual_output_cpu.shape={ actual_output_cpu .shape } ' )
824
+ print (f'expected_output_cpu.shape={ expected_output_cpu .shape } ' )
825
+ self .assertEqual (actual_output_cpu .shape , expected_output_cpu .shape )
826
+ # torch.set_printoptions(profile="full")
827
+ print (f'{ (actual_output_cpu - expected_output_cpu ).abs ()} ' )
828
+ print (f'Output max diff: { (expected_output_cpu - actual_output_cpu ).abs ().max ().item ()} ' )
829
+ print (f'Output mean diff: { (expected_output_cpu - actual_output_cpu ).abs ().mean ().item ()} ' )
830
+ self .assertTrue (
831
+ torch .allclose (
832
+ expected_output_cpu ,
833
+ actual_output_cpu ,
834
+ atol = 1e-2 ,
835
+ rtol = 1e-2 ))
836
+
837
+ def _ref_jax_extended_paged_attention (
838
+ self ,
839
+ q , # [batch_size, query_len, num_query_heads, head_size]
840
+ k_pages ,# [num_kv_heads, total_num_pages, page_size, head_size]
841
+ v_pages ,# [num_kv_heads, total_num_pages, page_size, head_size]
842
+ lengths ,# [batch_size]
843
+ page_indices ,# [batch_size, pages_per_sequence]
844
+ ):
845
+ batch_size , query_len , num_query_heads , head_size = q .shape
846
+ num_kv_heads , total_num_pages , page_size , _ = k_pages .shape
847
+ num_query_per_kv = num_query_heads // num_kv_heads
848
+
849
+ lengths = lengths
850
+ page_indices = page_indices
851
+ outputs = []
852
+ for i in range (batch_size ):
853
+ kv_len = lengths [i ]
854
+ num_pages = (kv_len + page_size - 1 ) // page_size
855
+ indices = page_indices [i , :num_pages ]
856
+
857
+ k = k_pages [:, indices ]
858
+ k = jnp .permute_dims (k , (1 , 2 , 0 , 3 ))
859
+ k = jnp .reshape (k , (num_pages * page_size , num_kv_heads , head_size ))
860
+ k = k [:kv_len ]
861
+
862
+ v = v_pages [:, indices ]
863
+ v = jnp .permute_dims (v , (1 , 2 , 0 , 3 ))
864
+ v = jnp .reshape (v , (num_pages * page_size , num_kv_heads , head_size ))
865
+ v = v [:kv_len ]
866
+
867
+ if num_query_per_kv != 1 :
868
+ k = jnp .repeat (k , num_query_per_kv , axis = 1 )
869
+ v = jnp .repeat (v , num_query_per_kv , axis = 1 )
870
+
871
+ attn = jnp .einsum ("qhd,khd->hqk" , q [i ], k )
872
+ attn = attn .astype ('float32' )
873
+ q_span = (kv_len - query_len ) + jax .lax .broadcasted_iota (
874
+ jnp .int32 , (query_len , kv_len ), 0
875
+ )
876
+ kv_span = jax .lax .broadcasted_iota (
877
+ jnp .int32 , (query_len , kv_len ), 1
878
+ )
879
+ mask = jnp .where (q_span < kv_span , float ("-inf" ), 0. )
880
+ with jax .numpy_rank_promotion ("allow" ):
881
+ attn = attn + mask
882
+ attn = jax .nn .softmax (attn , axis = - 1 ).astype (v .dtype )
883
+ out = jnp .einsum ("hqk,khd->qhd" , attn , v )
884
+ outputs .append (out )
885
+ output = jnp .stack (outputs , axis = 0 )
886
+ return output
887
+
888
+ @unittest .skipIf (xr .device_type () != 'TPU' or tpu .version () < 4 ,
889
+ "This test only works on TPUv4+." )
890
+ def test_extended_paged_attention_jax_ref_impl (self ):
891
+ # Use multiple queries, verify my jax ref impl of paged attention against
892
+ # Woosuk's pytorch ref impl. Should get the same result.
893
+ from torch_xla .experimental .custom_kernel import ref_extended_paged_attention
894
+
895
+ batch_size : int = 3
896
+ head_size : int = 128
897
+ dtype_torch : torch .dtype = torch .float32
898
+ max_kv_len : int = 1024
899
+ total_num_pages : int = 32
900
+ pages_per_sequence = total_num_pages
901
+
902
+ # num_compute_blks_q=1, num_compute_blks_kv=1,num_q_heads_per_kv_head=8
903
+ # num_compute_blks_q=(query_len//num_queries_per_compute_block)
904
+ # Change num_queries_per_compute_block to adjust num_compute_blks_q
905
+ # num_compute_blks_kv=(pages_per_sequence//num_kv_pages_per_compute_block) where
906
+ # pages_per_sequence is the same as total_num_pages
907
+ # Change pallas_compute_block_size to adjust num_compute_blks_kv
908
+ pallas_compute_block_size = 2048
909
+ page_size : int = 64
910
+ num_kv_pages_per_compute_block = pallas_compute_block_size // page_size
911
+ query_len : int = 8
912
+ num_queries_per_compute_block = 8
913
+ num_query_heads : int = 64
914
+ num_kv_heads : int = 8
915
+
916
+ assert num_query_heads % num_kv_heads == 0
917
+ assert query_len <= max_kv_len
918
+ assert max_kv_len <= total_num_pages * page_size
919
+
920
+ print (f'The test test_extended_paged_attention_multiple_queries begins with { query_len = } ' )
921
+ q = torch .randn (batch_size , query_len , num_query_heads , head_size , dtype = dtype_torch )
922
+ k_pages = torch .randn (num_kv_heads , total_num_pages , page_size , head_size , dtype = dtype_torch )
923
+ v_pages = torch .rand_like (k_pages )
924
+ kv_seq_lengths = torch .randint (query_len , max_kv_len + 1 , (batch_size ,))
925
+ page_indices = torch .randint (0 , total_num_pages , (batch_size , total_num_pages ))
926
+
927
+ # Run the jax ref impl with query_len>1
928
+ q_jax = jnp .array (q .numpy (), dtype = jnp .float32 )
929
+ assert q_jax .shape == (batch_size , query_len , num_query_heads , head_size ), f"Input q_jax has the wrong shape: { q_jax .shape } . Expect { (batch_size , query_len , num_query_heads , head_size )} ."
930
+ k_pages_jax = jnp .array (k_pages .numpy (), dtype = jnp .float32 )
931
+ v_pages_jax = jnp .array (v_pages .numpy (), dtype = jnp .float32 )
932
+ kv_seq_lens_jax = jnp .array (kv_seq_lengths .numpy (), dtype = jnp .int32 )
933
+ page_indices_jax = jnp .array (page_indices .numpy (), dtype = jnp .int32 )
934
+ print ('xw32 calling jax_extended_paged_attention1' )
935
+ actual_output = self ._ref_jax_extended_paged_attention (
936
+ q_jax ,
937
+ k_pages_jax ,
938
+ v_pages_jax ,
939
+ kv_seq_lens_jax ,
940
+ page_indices_jax ,
941
+ )
942
+
943
+
944
+ # Run Woosuk's non-kernel impl.
945
+ assert q .shape == (batch_size , query_len , num_query_heads , head_size ), f"Input ref_q_torch has the wrong shape: { q .shape } . Expect { (batch_size , query_len , num_query_heads , head_size )} ."
946
+ assert jnp .allclose (q_jax , jnp .array (q .numpy (), dtype = jnp .float32 ))
947
+ assert jnp .allclose (k_pages_jax , jnp .array (k_pages .numpy (), dtype = jnp .float32 ))
948
+ assert jnp .allclose (v_pages_jax , jnp .array (v_pages .numpy (), dtype = jnp .float32 ))
949
+ assert jnp .allclose (kv_seq_lens_jax , jnp .array (kv_seq_lengths .numpy (), dtype = jnp .int32 ))
950
+ assert jnp .allclose (page_indices_jax , jnp .array (page_indices .numpy (), dtype = jnp .int32 ))
951
+
952
+ expected_output = ref_extended_paged_attention (
953
+ q ,
954
+ k_pages ,
955
+ v_pages ,
956
+ kv_seq_lengths ,
957
+ page_indices ,
958
+ )
959
+
960
+ expected_output_cpu = expected_output .cpu ()
961
+ actual_output_cpu = torch .from_numpy (np .array (actual_output ))
826
962
# print(f'{expected_output_cpu=}')
827
963
# print(f'{actual_output_cpu=}')
828
964
print (f'actual_output_cpu.shape={ actual_output_cpu .shape } ' )
829
965
print (f'expected_output_cpu.shape={ expected_output_cpu .shape } ' )
830
966
self .assertEqual (actual_output_cpu .shape , expected_output_cpu .shape )
831
967
# torch.set_printoptions(profile="full")
832
- print (f'{ (actual_output_cpu - expected_output_cpu ).abs ()} ' )
968
+ # print(f'{(actual_output_cpu-expected_output_cpu).abs()}')
833
969
print (f'Output max diff: { (expected_output_cpu - actual_output_cpu ).abs ().max ().item ()} ' )
834
970
print (f'Output mean diff: { (expected_output_cpu - actual_output_cpu ).abs ().mean ().item ()} ' )
835
971
self .assertTrue (
836
972
torch .allclose (
837
973
expected_output_cpu ,
838
974
actual_output_cpu ,
839
- atol = 1e-5 ,
840
- rtol = 1e-5 ))
975
+ atol = 3e-2 ,
976
+ rtol = 2e-2 ))
841
977
842
978
@unittest .skipIf (xr .device_type () != 'TPU' or tpu .version () < 4 ,
843
979
"This test only works on TPUv4+." )
0 commit comments