|
7 | 7 | import triton |
8 | 8 | import triton.language as tl |
9 | 9 |
|
10 | | -from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4, is_hip_rdna4 |
| 10 | +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4, is_hip_rdna3, is_hip_rdna4 |
| 11 | + |
| 12 | +FP8_DTYPES = ('float8e5', 'float8e4b15', 'float8e4nv', 'float8e4b8', 'float8e5b16') |
11 | 13 |
|
12 | 14 |
|
13 | 15 | def matching_int(dtype): |
@@ -283,6 +285,8 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device): |
283 | 285 | launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) |
284 | 286 | return |
285 | 287 | elif is_hip(): |
| 288 | + if src_dtype in FP8_DTYPES and is_hip_rdna3(): |
| 289 | + pytest.skip(f"{src_dtype} is not supported on AMDGPU RDNA3") |
286 | 290 | if (src_dtype == 'float8e4nv' and not (is_hip_cdna3() or is_hip_cdna4())): |
287 | 291 | pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture") |
288 | 292 | if src_dtype == 'float8e4b15': |
@@ -343,6 +347,8 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): |
343 | 347 | pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3") |
344 | 348 |
|
345 | 349 | if is_hip(): |
| 350 | + if dst_dtype in FP8_DTYPES and is_hip_rdna3(): |
| 351 | + pytest.skip(f"{dst_dtype} is not supported on AMDGPU RDNA3") |
346 | 352 | if dst_dtype in ('float8e4b8', 'float8e5b16') and (is_hip_cdna2() or is_hip_rdna4()): |
347 | 353 | pytest.skip(f"{dst_dtype} is not supported on AMDGPU CDNA2 and RDNA4") |
348 | 354 |
|
@@ -373,6 +379,9 @@ def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, device, round |
373 | 379 | if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0): |
374 | 380 | pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") |
375 | 381 |
|
| 382 | + if dst_dtype in FP8_DTYPES and is_hip_rdna3(): |
| 383 | + pytest.skip(f"{dst_dtype} is not supported on AMDGPU RDNA3") |
| 384 | + |
376 | 385 | if mode in ('inf', '-inf') and is_hip_rdna4(): |
377 | 386 | pytest.skip(f"clamping from `{mode}` is not supported on AMDGPU GFX12") |
378 | 387 |
|
|
0 commit comments