Commit 070138a: fix test

1 parent 5749e4c

2 files changed (+16, -4 lines)

test/test_pallas.py

Lines changed: 6 additions & 0 deletions
@@ -663,6 +663,10 @@ def _test_ragged_paged_attention(
         max_num_seqs=max_num_seqs)
     k_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
     v_scale = 0.5 if kv_dtype in [torch.float8_e5m2] else None
+    if num_heads[1] == 1 and kv_dtype in [torch.float8_e5m2]:
+      self.skipTest("attention kernel cannot support ")
+    if kv_dtype is torch.float8_e5m2 and tpu.version() <= 4:
+      self.skipTest("TPU v4 or older doesn't support fp8")

     q_xla = q.to("xla")
     kv_pages_xla = kv_pages.to("xla")
@@ -778,6 +782,8 @@ def ragged_paged_attention_wrapper(
             sm_scale=sm_scale,
             sliding_window=sliding_window,
             soft_cap=soft_cap,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(q_dtype)
     jax_kernel_output_cpu = jax_kernel_output.cpu()
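The new test path drives the kernel with a fixed k_scale/v_scale of 0.5 whenever the KV cache is float8_e5m2. As a hedged sketch (illustrative, not code from this commit), this is roughly what that scale means: the cache stores values divided by the scale in fp8, and a consumer upcasts and multiplies the scale back in.

```python
import torch

kv = torch.randn(16, 8, 128, dtype=torch.float32)
scale = 0.5  # mirrors k_scale / v_scale in the test above

# Quantize: divide by the scale, then cast down to fp8 storage.
kv_fp8 = (kv / scale).to(torch.float8_e5m2)

# Dequantize: cast back up to float32 and multiply the scale in.
kv_deq = kv_fp8.to(torch.float32) * scale

print((kv - kv_deq).abs().max())  # small fp8 rounding error, not zero
```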

torch_xla/experimental/custom_kernel.py

Lines changed: 10 additions & 4 deletions
@@ -183,6 +183,12 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
     return jnp.int8
   elif dtype == torch.uint8:
     return jnp.uint8
+  elif dtype == torch.float8_e5m2:
+    return jnp.float8_e5m2
+  elif dtype == torch.float8_e4m3fn:
+    return jnp.float8_e4m3fn
+  elif dtype == torch.float8_e4m3fnuz:
+    return jnp.float8_e4m3fnuz
   else:
     raise ValueError(f"Unsupported dtype: {dtype}")

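The three new branches extend the torch-to-JAX dtype mapping to the fp8 formats. A minimal standalone sketch of the same correspondence as a lookup table (hypothetical helper name to_jax_dtype; not the library's API):

```python
import jax.numpy as jnp
import torch

# The same torch -> JAX correspondence as the elif chain above.
_TORCH_TO_JAX = {
    torch.int8: jnp.int8,
    torch.uint8: jnp.uint8,
    torch.float8_e5m2: jnp.float8_e5m2,
    torch.float8_e4m3fn: jnp.float8_e4m3fn,
    torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
}

def to_jax_dtype(dtype: torch.dtype) -> "jnp.dtype":
  try:
    return _TORCH_TO_JAX[dtype]
  except KeyError:
    raise ValueError(f"Unsupported dtype: {dtype}")

assert to_jax_dtype(torch.float8_e5m2) == jnp.float8_e5m2
```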

@@ -930,11 +936,11 @@ def _ragged_paged_attention_nonkernel(
     v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads,
                                               head_dim)[:kv_len]
     if k_scale is not None:
-      k = k.astype(torch.float32) * k_scale
-      k = k.astype(q.dtype)
+      k = k.to(torch.float32) * k_scale
+      k = k.to(q.dtype)
     if v_scale is not None:
-      v = v.astype(torch.float32) * v_scale
-      v = v.astype(q.dtype)
+      v = v.to(torch.float32) * v_scale
+      v = v.to(q.dtype)
     k = torch.repeat_interleave(k, num_query_per_kv, dim=1)
     v = torch.repeat_interleave(v, num_query_per_kv, dim=1)
     attn = torch.einsum("qhd,khd->hqk", q, k)
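The substance of this hunk is the astype-to-to swap: torch.Tensor has no astype method (that spelling belongs to NumPy/JAX arrays), so the reference path would raise AttributeError whenever a scale was set. A quick sketch of the failure and the fix:

```python
import torch

k = torch.randn(4, 2, 8).to(torch.float8_e5m2)
k_scale = 0.5

try:
  k.astype(torch.float32)  # NumPy/JAX spelling; not a torch.Tensor method
except AttributeError as e:
  print(e)  # 'Tensor' object has no attribute 'astype'

k = k.to(torch.float32) * k_scale  # the torch cast used by the fix
```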

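The surrounding context lines show the reference GQA expansion: each KV head is repeated num_query_per_kv times so the einsum sees one KV head per query head. A self-contained sketch with made-up shapes:

```python
import torch

q_len, kv_len, head_dim = 3, 5, 16
num_q_heads, num_kv_heads = 8, 2
num_query_per_kv = num_q_heads // num_kv_heads  # 4

q = torch.randn(q_len, num_q_heads, head_dim)
k = torch.randn(kv_len, num_kv_heads, head_dim)

k = torch.repeat_interleave(k, num_query_per_kv, dim=1)  # (5, 8, 16)
attn = torch.einsum("qhd,khd->hqk", q, k)  # (heads, q_len, kv_len)
print(attn.shape)  # torch.Size([8, 3, 5])
```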