Skip to content

Commit 42e54cb

Browse files
[AMD] Skip fp8 data type tests on RDNA3 for test_conversions (#9895)
Navi3 does not support fp8 data type. This PR skip all the fp8 tests in test_conversions for RDNA3. --------- Co-authored-by: root <saeidrostami.github@gmail.com>
1 parent 453d1c9 commit 42e54cb

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

python/test/unit/language/test_conversions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import triton
88
import triton.language as tl
99

10-
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4, is_hip_rdna4
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4, is_hip_rdna3, is_hip_rdna4
11+
12+
FP8_DTYPES = ('float8e5', 'float8e4b15', 'float8e4nv', 'float8e4b8', 'float8e5b16')
1113

1214

1315
def matching_int(dtype):
@@ -283,6 +285,8 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
283285
launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
284286
return
285287
elif is_hip():
288+
if src_dtype in FP8_DTYPES and is_hip_rdna3():
289+
pytest.skip(f"{src_dtype} is not supported on AMDGPU RDNA3")
286290
if (src_dtype == 'float8e4nv' and not (is_hip_cdna3() or is_hip_cdna4())):
287291
pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
288292
if src_dtype == 'float8e4b15':
@@ -343,6 +347,8 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
343347
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
344348

345349
if is_hip():
350+
if dst_dtype in FP8_DTYPES and is_hip_rdna3():
351+
pytest.skip(f"{dst_dtype} is not supported on AMDGPU RDNA3")
346352
if dst_dtype in ('float8e4b8', 'float8e5b16') and (is_hip_cdna2() or is_hip_rdna4()):
347353
pytest.skip(f"{dst_dtype} is not supported on AMDGPU CDNA2 and RDNA4")
348354

@@ -373,6 +379,9 @@ def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, device, round
373379
if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
374380
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
375381

382+
if dst_dtype in FP8_DTYPES and is_hip_rdna3():
383+
pytest.skip(f"{dst_dtype} is not supported on AMDGPU RDNA3")
384+
376385
if mode in ('inf', '-inf') and is_hip_rdna4():
377386
pytest.skip(f"clamping from `{mode}` is not supported on AMDGPU GFX12")
378387

0 commit comments

Comments
 (0)