
Commit 057c40b: fix test

Parent: d3655f8

File tree: 1 file changed (+12, -3)

test/test_pallas.py

@@ -660,6 +660,8 @@ def _test_ragged_paged_attention(
         kv_dtype,
         max_num_batched_tokens=max_num_batched_tokens,
         max_num_seqs=max_num_seqs)
+    k_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    v_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
 
     q_xla = q.to("xla")
     kv_pages_xla = kv_pages.to("xla")
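
The two added lines give the fp8 path a fixed dequantization scale: 0.5 when the KV cache is stored as torch.float8_e5m2, and None (no rescaling) otherwise. As a rough sketch of what such a per-tensor scale does, assuming the kernel upcasts fp8 pages and multiplies by the scale before the attention matmul (the helper below is hypothetical, not part of this commit):

import torch

# Hypothetical helper, not from this commit: dequantize an fp8 KV block
# back to a compute dtype using a per-tensor scale.
def dequantize_kv(kv_fp8, scale, compute_dtype=torch.float32):
  # fp8 tensors are storage-only for most ops, so upcast first, then rescale.
  return kv_fp8.to(compute_dtype) * scale

k_fp8 = torch.randn(4, 8).to(torch.float8_e5m2)
k = dequantize_kv(k_fp8, scale=0.5)  # mirrors k_scale = 0.5 in the test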
@@ -680,6 +682,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -694,6 +698,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=use_kernel,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -714,6 +720,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -729,6 +737,8 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
+        k_scale=k_scale,
+        v_scale=v_scale,
         use_kernel=False,
     )
 
@@ -737,10 +747,9 @@ def ragged_paged_attention_wrapper(
     self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
     self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)
 
-    assert dtype == torch.float32 or dtype == torch.bfloat16
     jnp_dtype = jnp.float32
     tol = 0.15
-    if dtype == torch.bfloat16:
+    if q_dtype == torch.bfloat16:
       jnp_dtype = jnp.bfloat16
       tol = 0.3
 
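The hunk above also drops the old assert that the single dtype be float32 or bfloat16; with the test now distinguishing the query dtype (q_dtype) from the KV-cache dtype (kv_dtype), the reference dtype and comparison tolerance key off q_dtype. Restated as a small standalone helper (not a function in the test itself):

import torch
import jax.numpy as jnp

# Restatement of the selection logic above, for reference only.
def comparison_params(q_dtype):
  jnp_dtype, tol = jnp.float32, 0.15
  if q_dtype == torch.bfloat16:
    jnp_dtype, tol = jnp.bfloat16, 0.3
  return jnp_dtype, tol
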
@@ -768,7 +777,7 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-    )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(dtype)
+    )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(q_dtype)
     jax_kernel_output_cpu = jax_kernel_output.cpu()
 
     torch.testing.assert_close(
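
With the reference output trimmed to the valid query rows and cast to q_dtype, the test compares kernel and reference tensors via torch.testing.assert_close. A self-contained illustration of that final check, using stand-in tensors and the bfloat16 tolerance from above (the test's exact atol/rtol arguments are not shown in this diff):

import torch

tol = 0.3  # bfloat16 tolerance from the hunk above
expected = torch.randn(16, 8, dtype=torch.bfloat16)
actual = expected + 0.01 * torch.randn_like(expected)
torch.testing.assert_close(actual, expected, atol=tol, rtol=tol)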
