@@ -660,6 +660,8 @@ def _test_ragged_paged_attention(
         kv_dtype,
         max_num_batched_tokens=max_num_batched_tokens,
         max_num_seqs=max_num_seqs)
+    k_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    v_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None

     q_xla = q.to("xla")
     kv_pages_xla = kv_pages.to("xla")
@@ -680,6 +682,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -694,6 +698,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=use_kernel,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -714,6 +720,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -729,6 +737,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=False,
     )

@@ -737,10 +747,9 @@ def ragged_paged_attention_wrapper(
     self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
     self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)

-    assert dtype == torch.float32 or dtype == torch.bfloat16
     jnp_dtype = jnp.float32
     tol = 0.15
-    if dtype == torch.bfloat16:
+    if q_dtype == torch.bfloat16:
       jnp_dtype = jnp.bfloat16
       tol = 0.3

@@ -768,7 +777,7 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-    )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(dtype)
+    )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(q_dtype)
     jax_kernel_output_cpu = jax_kernel_output.cpu()

     torch.testing.assert_close(
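For context on the new parameters: k_scale and v_scale act as per-tensor dequantization factors for a KV cache stored in float8 (the test picks 0.5 whenever kv_dtype is torch.float8_e5m2, and None otherwise). The following is a minimal, hypothetical sketch of how such scales are conventionally applied in a reference attention path; the shapes and the function name are illustrative assumptions, not the ragged_paged_attention kernel implementation.

import torch

def reference_attention_with_kv_scales(q, k_fp8, v_fp8, k_scale=None, v_scale=None):
  # Hypothetical reference path, not the Pallas kernel.
  # q: [T, H, D] in the compute dtype; k_fp8, v_fp8: [S, H, D] in torch.float8_e5m2.
  # Dequantize keys/values back to the compute dtype before the usual attention math.
  k = k_fp8.to(q.dtype) * (k_scale if k_scale is not None else 1.0)
  v = v_fp8.to(q.dtype) * (v_scale if v_scale is not None else 1.0)
  # Scaled dot-product attention over the sequence dimension S.
  attn = torch.einsum("thd,shd->hts", q, k) * (q.shape[-1] ** -0.5)
  return torch.einsum("hts,shd->thd", attn.softmax(dim=-1), v)

Passing k_scale=None / v_scale=None (the non-float8 case in the test) skips the rescaling entirely, which matches how the diff only sets the scales when kv_dtype is a float8 type.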