@@ -14,7 +14,7 @@
 import numpy as np

 if xr.device_type() == 'TPU':
-  from torch_xla.experimental.custom_kernel import jax_import_guard
+  from torch_xla.experimental.custom_kernel import jax_import_guard, convert_torch_dtype_to_jax
   jax_import_guard()
   import jax
   import jax.numpy as jnp
@@ -98,7 +98,8 @@ def _ragged_pagedattention_generate_qkv(
       head_dim,
       page_size,
       num_pages,
-      dtype,
+      q_dtype,
+      kv_dtype,
       *,
       max_num_batched_tokens=None,
       max_num_seqs=16,
@@ -129,10 +130,11 @@ def _ragged_pagedattention_generate_qkv(
     kv_lens = torch.nn.functional.pad(kv_lens,
                                       (0, max_num_seqs - kv_lens.shape[0]),
                                       "constant", 0)
+    # Use float32 for randn because it doesn't support some dtypes like float8
     q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
-                    dtype=dtype)
+                    dtype=torch.float32).to(q_dtype)
     kv_pages = torch.randn((num_pages, page_size, num_kv_heads * 2, head_dim),
-                           dtype=dtype)
+                           dtype=torch.float32).to(kv_dtype)
     page_indices = torch.randint(
         0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
     return q, kv_pages, kv_lens, page_indices, cu_q_lens
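
Note: the float32-then-cast pattern in the hunk above exists because torch.randn has no sampling kernels for float8 dtypes, so the helper draws values in float32 and converts afterwards. A minimal standalone sketch of the same idea (shapes are illustrative only, not the ones the test uses):

import torch

# randn cannot produce float8 tensors directly, so sample in float32 first
# and cast to the target dtypes afterwards, mirroring the test helper above.
q = torch.randn((16, 8, 128), dtype=torch.float32).to(torch.bfloat16)
kv_pages = torch.randn((32, 16, 4, 128), dtype=torch.float32).to(torch.float8_e5m2)

print(q.dtype, kv_pages.dtype)  # torch.bfloat16 torch.float8_e5m2
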
@@ -632,7 +634,8 @@ def _test_ragged_paged_attention(
       head_dim,
       page_size,
       num_pages,
-      dtype,
+      q_dtype,
+      kv_dtype,
       *,
       sm_scale=1.0,
       sliding_window=None,
@@ -654,9 +657,18 @@ def _test_ragged_paged_attention(
         head_dim,
         page_size,
         num_pages,
-        dtype,
+        q_dtype,
+        kv_dtype,
         max_num_batched_tokens=max_num_batched_tokens,
         max_num_seqs=max_num_seqs)
+    k_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    v_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    num_kv_heads = num_heads[1]
+    if num_kv_heads == 1 and kv_dtype in [torch.float8_e5m2]:
+      self.skipTest(
+          "attention kernel cannot support this case because it is not fully tiled by XLA")
+    if kv_dtype is torch.float8_e5m2 and tpu.version() <= 4:
+      self.skipTest("TPU v4 or older doesn't support fp8")

     q_xla = q.to("xla")
     kv_pages_xla = kv_pages.to("xla")
@@ -677,6 +689,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -691,6 +705,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=use_kernel,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -711,6 +727,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -726,6 +744,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=False,
     )

@@ -734,17 +754,14 @@ def ragged_paged_attention_wrapper(
     self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
     self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)

-    assert dtype == torch.float32 or dtype == torch.bfloat16
-    jnp_dtype = jnp.float32
-    tol = 0.15
-    if dtype == torch.bfloat16:
-      jnp_dtype = jnp.bfloat16
-      tol = 0.3
+    tol = 0.15 if q_dtype == torch.float32 else 0.3
+    q_jnp_dtype = convert_torch_dtype_to_jax(q_dtype)
+    kv_jnp_dtype = convert_torch_dtype_to_jax(kv_dtype)

     # Numpy does not support bfloat16 directly. So we convert f32 first.
-    q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=jnp_dtype)
+    q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=q_jnp_dtype)
     kv_pages_jax = jnp.array(
-        kv_pages.to(torch.float32).numpy(), dtype=jnp_dtype)
+        kv_pages.to(torch.float32).numpy(), dtype=kv_jnp_dtype)
     kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
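
Note: the kernel path and the JAX reference must see matching dtypes, which is what convert_torch_dtype_to_jax provides in the hunk above. As a rough sketch of what such a mapping involves, assuming a simple lookup table (the helper name torch_dtype_to_jax and its coverage below are illustrative, not the torch_xla implementation):

import torch
import jax.numpy as jnp

# Hypothetical torch -> JAX dtype table; the real convert_torch_dtype_to_jax in
# torch_xla.experimental.custom_kernel may cover more dtypes and report errors differently.
_TORCH_TO_JAX = {
    torch.float32: jnp.float32,
    torch.bfloat16: jnp.bfloat16,
    torch.float8_e5m2: jnp.float8_e5m2,
    torch.float8_e4m3fn: jnp.float8_e4m3fn,
}


def torch_dtype_to_jax(dtype):
  # Fail early for dtypes the tests don't exercise.
  if dtype not in _TORCH_TO_JAX:
    raise ValueError(f"no JAX equivalent registered for {dtype}")
  return _TORCH_TO_JAX[dtype]
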
@@ -765,7 +782,9 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-    )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(dtype)
+        k_scale=k_scale,
+        v_scale=v_scale,
+    )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(q_dtype)
     jax_kernel_output_cpu = jax_kernel_output.cpu()

     torch.testing.assert_close(
@@ -776,7 +795,8 @@ def ragged_paged_attention_wrapper(
   @parameterized.product(
       seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
       num_heads=[(32, 8), (8, 1)],
-      dtype=[torch.float32, torch.bfloat16],
+      dtype=[(torch.bfloat16, torch.bfloat16),
+             (torch.bfloat16, torch.float8_e5m2)],
       sm_scale=[1.0, 0.5],
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
@@ -796,14 +816,16 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
     head_dim = 128
     page_size = 16
     num_pages = 1000
+    q_dtype, kv_dtype = dtype

     self._test_ragged_paged_attention(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype,
+        q_dtype,
+        kv_dtype,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
@@ -814,7 +836,8 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
   @parameterized.product(
       seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
       num_heads=[(32, 8), (8, 1)],
-      dtype=[torch.float32, torch.bfloat16],
+      dtype=[(torch.bfloat16, torch.bfloat16),
+             (torch.bfloat16, torch.float8_e5m2)],
      sm_scale=[1.0, 0.5],
      sliding_window=[None, 128],
      soft_cap=[None, 10.0],
@@ -835,14 +858,16 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
     head_dim = 128
     page_size = 16
     num_pages = 1000
+    q_dtype, kv_dtype = dtype

     self._test_ragged_paged_attention(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype,
+        q_dtype,
+        kv_dtype,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
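
Note: this diff only threads k_scale and v_scale through the call sites; the kernel applies them internally, and how it does so is not shown here. As a sketch of the underlying idea, assuming a per-tensor scale that dequantizes the fp8 cache by multiplication and the interleaved K/V page layout used by the test (both assumptions, not taken from this diff), a reference-side dequantization could look roughly like:

import torch


def dequantize_kv_pages(kv_pages, k_scale=None, v_scale=None):
  # kv_pages: (num_pages, page_size, num_kv_heads * 2, head_dim), with K and V
  # assumed interleaved along dim 2. Upcast to float32 and apply the per-tensor
  # scales; the real kernel fuses this into the attention computation instead.
  kv = kv_pages.to(torch.float32)
  if k_scale is not None:
    kv[:, :, 0::2, :] *= k_scale
  if v_scale is not None:
    kv[:, :, 1::2, :] *= v_scale
  return kv


# Mirrors the test's choice of 0.5 for float8_e5m2 caches.
pages = torch.randn(4, 16, 8, 128, dtype=torch.float32).to(torch.float8_e5m2)
deq = dequantize_kv_pages(pages, k_scale=0.5, v_scale=0.5)
print(deq.dtype, deq.shape)  # torch.float32 torch.Size([4, 16, 8, 128])
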