From 83c74b96001553ec31757ded6969a86ebf2eeedb Mon Sep 17 00:00:00 2001 From: jeffniu-openai Date: Thu, 11 Jun 2026 18:02:11 -0700 Subject: [PATCH] [FPSAN] Broaden FPSan MMA dtype and minimum-shape coverage --- python/test/gluon/test_fpsan.py | 722 +++++++++++++++++++++----------- 1 file changed, 488 insertions(+), 234 deletions(-) diff --git a/python/test/gluon/test_fpsan.py b/python/test/gluon/test_fpsan.py index e98b7de1d550..34bfbd3760e6 100644 --- a/python/test/gluon/test_fpsan.py +++ b/python/test/gluon/test_fpsan.py @@ -9,6 +9,7 @@ from triton.experimental.gluon import language as gl from triton import language as tl from triton._internal_testing import is_blackwell, is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4, is_hip_gfx1250, is_hopper, is_interpreter +from triton.experimental.gluon.language.nvidia.ampere import mma_v2 from triton.experimental.gluon.language.nvidia import hopper from triton.experimental.gluon.language.nvidia.blackwell import ( TensorMemoryLayout, @@ -130,6 +131,77 @@ def _signed_cast_payload_u64(payload, src_bitwidth: int, dst_bitwidth: int) -> n return np.where((x & sign) != 0, x | extension, x) & _low_mask_u64(dst_bitwidth) +_FLOAT_DTYPE_INFO = { + "f64": (64, 0x3FF0000000000000, np.int64, torch.int64, torch.float64, gl.float64), + "f32": (32, 0x3F800000, np.int32, torch.int32, torch.float32, gl.float32), + "f16": (16, 0x3C00, np.int16, torch.int16, torch.float16, gl.float16), + "bf16": (16, 0x3F80, np.int16, torch.int16, torch.bfloat16, gl.bfloat16), + "e4m3": (8, 0x38, np.int8, torch.int8, torch.float8_e4m3fn, gl.float8e4nv), + "e5m2": (8, 0x3C, np.int8, torch.int8, torch.float8_e5m2, gl.float8e5), + "e4m3fnuz": (8, 0x40, np.int8, torch.int8, torch.float8_e4m3fnuz, gl.float8e4b8), + "e5m2fnuz": (8, 0x40, np.int8, torch.int8, torch.float8_e5m2fnuz, gl.float8e5b16), +} + + +def _float_dtype_info(dtype: str): + return _FLOAT_DTYPE_INFO[dtype] + + +def _float_payload_edges(bitwidth: int) -> np.ndarray: + if bitwidth == 8: + return np.asarray([0x00, 0x01, 0x7F, 0x80, 0x81, 0xFF], dtype=np.uint64) + if bitwidth == 16: + return np.asarray([0x0000, 0x0001, 0x00FF, 0x0100, 0x7FFF, 0x8000, 0x8001, 0xFFFF], dtype=np.uint64) + if bitwidth == 32: + return np.asarray([ + 0x00000000, 0x00000001, 0x000000FF, 0x00000100, 0x0000FFFF, 0x00010000, 0x7FFFFFFF, 0x80000000, 0x80000001, + 0xFFFFFFFF + ], dtype=np.uint64) + assert bitwidth == 64 + return np.asarray([ + 0x0000000000000000, + 0x0000000000000001, + 0x00000000000000FF, + 0x0000000000000100, + 0x00000000FFFFFFFF, + 0x0000000100000000, + 0x7FFFFFFFFFFFFFFF, + 0x8000000000000000, + 0x8000000000000001, + 0xFFFFFFFFFFFFFFFF, + ], dtype=np.uint64) + + +def _random_float_bits(rs: np.random.RandomState, shape, dtype: str) -> np.ndarray: + bitwidth, one_bits, np_storage_dtype, _, _, _ = _float_dtype_info(dtype) + high = np.iinfo(np.uint64).max if bitwidth == 64 else 1 << bitwidth + payload = rs.randint(0, high, size=shape, dtype=np.uint64) + edges = _float_payload_edges(bitwidth) + edge_count = min(payload.size, len(edges)) + payload.reshape(-1)[:edge_count] = edges[:edge_count] + bits = _unmix_payload_u64_to_float_bits(payload, bitwidth, one_bits) + np_unsigned_dtype = np.dtype(f"u{bitwidth // 8}") + return bits.astype(np_unsigned_dtype).view(np_storage_dtype) + + +def _as_float_bits_tensor(bits: np.ndarray, dtype: str): + _, _, _, torch_storage_dtype, torch_dtype, _ = _float_dtype_info(dtype) + storage = torch.tensor(bits, device="cuda", dtype=torch_storage_dtype) + return storage, triton.TensorWrapper(storage, dtype=torch_dtype) + + +def _mix_float_bits(bits: np.ndarray, dtype: str) -> np.ndarray: + bitwidth, one_bits, _, _, _, _ = _float_dtype_info(dtype) + return _mix_float_bits_to_payload_u64(bits, bitwidth, one_bits) + + +def _unmix_payload_to_float_bits(payload: np.ndarray, dtype: str) -> np.ndarray: + bitwidth, one_bits, np_storage_dtype, _, _, _ = _float_dtype_info(dtype) + bits = _unmix_payload_u64_to_float_bits(payload, bitwidth, one_bits) + np_unsigned_dtype = np.dtype(f"u{bitwidth // 8}") + return bits.astype(np_unsigned_dtype).view(np_storage_dtype) + + def _payload_u32_to_f32_bits_i32(x_u64: np.ndarray) -> np.ndarray: return _unmix_payload_u32_to_f32_bits_i32((x_u64 & np.uint64(0xFFFFFFFF)).astype(np.uint32)) @@ -391,26 +463,22 @@ def _extern_backend_name() -> str: EXTERN_MIXED_CASES = MIXED_EXTERN_SYMBOLS[_extern_backend_name()] -def _as_payload_np_i32(x) -> np.ndarray: +def _as_payload_np_unsigned(x) -> np.ndarray: if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() if not isinstance(x, np.ndarray): raise TypeError(f"unsupported input type: {type(x)}") - if x.dtype == np.int32: - return x.astype(np.int32, copy=False) - if x.dtype == np.uint32: - return x.view(np.int32) - if x.dtype == np.float32: - return x.view(np.int32) + if x.dtype.kind in "iuf" and x.dtype.itemsize in (1, 2, 4, 8): + return x.view(np.dtype(f"u{x.dtype.itemsize}")) raise TypeError(f"unsupported dtype for payload comparison: {x.dtype}") def _assert_payload_equal(actual, expected) -> None: - np.testing.assert_array_equal(_as_payload_np_i32(actual), _as_payload_np_i32(expected)) + np.testing.assert_array_equal(_as_payload_np_unsigned(actual), _as_payload_np_unsigned(expected)) def _payload_equal(a, b) -> bool: - return np.array_equal(_as_payload_np_i32(a), _as_payload_np_i32(b)) + return np.array_equal(_as_payload_np_unsigned(a), _as_payload_np_unsigned(b)) @gluon.jit @@ -1276,23 +1344,32 @@ def test_cast_ext_payload_semantics(device, fresh_knobs): _assert_payload_equal(out_np[:3], special_f32_bits) -def _mm_payload_u32(a_i32: np.ndarray, b_i32: np.ndarray, c_i32: np.ndarray = None) -> np.ndarray: - # Computes: c + a @ b in Z/(2^32) on mixed f32 payload bits. - a_u = _mix_f32_bits_to_payload_u32(a_i32).astype(np.uint64) - b_u = _mix_f32_bits_to_payload_u32(b_i32).astype(np.uint64) - c_u = _mix_f32_bits_to_payload_u32(c_i32).astype(np.uint64) if c_i32 is not None else None +def _mm_payload_bits(a_bits: np.ndarray, b_bits: np.ndarray, c_bits: np.ndarray, type_a: str, type_b: str, + acc_type: str) -> np.ndarray: + # Computes: c + a @ b in Z/(2^acc_width) on mixed float payload bits. + a_width = _float_dtype_info(type_a)[0] + b_width = _float_dtype_info(type_b)[0] + acc_width = _float_dtype_info(acc_type)[0] + a_u = _signed_cast_payload_u64(_mix_float_bits(a_bits, type_a), a_width, acc_width) + b_u = _signed_cast_payload_u64(_mix_float_bits(b_bits, type_b), b_width, acc_width) + c_u = _mix_float_bits(c_bits, acc_type) if c_bits is not None else None m, k = a_u.shape k2, n = b_u.shape assert k == k2 out = np.empty((m, n), dtype=np.uint64) - mask = np.uint64(0xFFFFFFFF) - for i in range(m): - for j in range(n): - s = c_u[i, j] if c_u is not None else 0 - for kk in range(k): - s = (s + (a_u[i, kk] * b_u[kk, j])) & mask - out[i, j] = s - return _unmix_payload_u32_to_f32_bits_i32(out.astype(np.uint32)) + mask = _low_mask_u64(acc_width) + with np.errstate(over="ignore"): + for i in range(m): + for j in range(n): + s = c_u[i, j] if c_u is not None else 0 + for kk in range(k): + s = (s + (a_u[i, kk] * b_u[kk, j])) & mask + out[i, j] = s + return _unmix_payload_to_float_bits(out, acc_type) + + +def _mm_payload_u32(a_i32: np.ndarray, b_i32: np.ndarray, c_i32: np.ndarray = None) -> np.ndarray: + return _mm_payload_bits(a_i32, b_i32, c_i32, "f32", "f32", "f32") def _bmm_payload_u32(a_i32: np.ndarray, b_i32: np.ndarray, c_i32: np.ndarray = None) -> np.ndarray: @@ -1374,6 +1451,9 @@ def _dot_scaled_payload_u32(a_data: np.ndarray, b_data: np.ndarray, a_scale, b_s M, N = a_data.shape[0], b_data.shape[1] K = a_data.shape[1] * a_pack compute_type = "fp16" if "fp16" in (type_a, type_b) else "bf16" + # CDNA4 converts raw E8M0 scale bytes to bf16 before scaled-upcast, even + # when the scaled-upcast result uses fp16. + scale_compute_type = "bf16" if is_hip_cdna4() else compute_type compute_mask = np.uint64(0xFFFF) mask = np.uint64(0xFFFFFFFF) out = np.zeros((M, N), dtype=np.uint64) @@ -1385,10 +1465,10 @@ def _dot_scaled_payload_u32(a_data: np.ndarray, b_data: np.ndarray, a_scale, b_s a_val = _dot_scaled_compute_payload_elem(a_val, type_a, compute_type) b_val = _dot_scaled_compute_payload_elem(b_val, type_b, compute_type) if a_scale is not None: - a_scale_val = _dot_scaled_scale_payload(np.uint64(a_scale[i, kk // 32]), compute_type) + a_scale_val = _dot_scaled_scale_payload(np.uint64(a_scale[i, kk // 32]), scale_compute_type) a_val = (a_val * a_scale_val) & compute_mask if b_scale is not None: - b_scale_val = _dot_scaled_scale_payload(np.uint64(b_scale[j, kk // 32]), compute_type) + b_scale_val = _dot_scaled_scale_payload(np.uint64(b_scale[j, kk // 32]), scale_compute_type) b_val = (b_val * b_scale_val) & compute_mask a_val = _signed_cast_payload_scalar(a_val, 16, 32) b_val = _signed_cast_payload_scalar(b_val, 16, 32) @@ -1398,8 +1478,8 @@ def _dot_scaled_payload_u32(a_data: np.ndarray, b_data: np.ndarray, a_scale, b_s def _mm_scaled_payload_u32(a_u8: np.ndarray, b_u8: np.ndarray, a_scale_u8: np.ndarray, b_scale_u8: np.ndarray, - c_i32: np.ndarray = None, a_pack: int = 1, b_pack: int = 1, - elem_type: str = "e2m1") -> np.ndarray: + c_i32: np.ndarray = None, a_pack: int = 1, b_pack: int = 1, type_a: str = "e2m1", + type_b: str = "e2m1", scale_factor: int = 32, scale_type: str = "e8m0") -> np.ndarray: a_scale = a_scale_u8.astype(np.uint64) b_scale = b_scale_u8.astype(np.uint64) c_u = _mix_f32_bits_to_payload_u32(c_i32).astype(np.uint64) if c_i32 is not None else None @@ -1408,8 +1488,8 @@ def _mm_scaled_payload_u32(a_u8: np.ndarray, b_u8: np.ndarray, a_scale_u8: np.nd n = b_u8.shape[1] k = a_u8.shape[1] * a_pack assert k == b_u8.shape[0] * b_pack - assert a_scale.shape == (m, k // 32) - assert b_scale.shape == (n, k // 32) + assert a_scale.shape == (m, k // scale_factor) + assert b_scale.shape == (n, k // scale_factor) def unpack_payload_matrix(data: np.ndarray, pack: int, pack_axis: int) -> np.ndarray: if pack == 1: @@ -1425,7 +1505,7 @@ def unpack_payload_matrix(data: np.ndarray, pack: int, pack_axis: int) -> np.nda out[1::2, :] = (data.astype(np.uint64) >> np.uint64(4)) & np.uint64(0x0F) return out - def compute_payload_matrix(data: np.ndarray) -> np.ndarray: + def compute_payload_matrix(data: np.ndarray, elem_type: str) -> np.ndarray: if elem_type in ("e4m3", "e5m2"): one_bits = 0x38 if elem_type == "e4m3" else 0x3C payload = _mix_float_bits_to_payload_u64(data, 8, one_bits) @@ -1433,20 +1513,24 @@ def compute_payload_matrix(data: np.ndarray) -> np.ndarray: return data & np.uint64(0xFFFF) def scale_payload_matrix(raw_scale: np.ndarray) -> np.ndarray: + if scale_type == "e4m3": + payload = _mix_float_bits_to_payload_u64(raw_scale, 8, 0x38) + return _signed_cast_payload_u64(payload, 8, 16) + assert scale_type == "e8m0" raw_bf16 = (raw_scale & np.uint64(0xFF)) << np.uint64(7) return _mix_float_bits_to_payload_u64(raw_bf16, 16, 0x3F80) - a_payload = compute_payload_matrix(unpack_payload_matrix(a_u8, a_pack, pack_axis=1)) - b_payload = compute_payload_matrix(unpack_payload_matrix(b_u8, b_pack, pack_axis=0)) + a_payload = compute_payload_matrix(unpack_payload_matrix(a_u8, a_pack, pack_axis=1), type_a) + b_payload = compute_payload_matrix(unpack_payload_matrix(b_u8, b_pack, pack_axis=0), type_b) a_scale_payload = scale_payload_matrix(a_scale) b_scale_payload = scale_payload_matrix(b_scale) out = c_u.copy() if c_u is not None else np.zeros((m, n), dtype=np.uint64) compute_mask = np.uint64(0xFFFF) mask32 = np.uint64(0xFFFFFFFF) - for group in range(k // 32): - start = group * 32 - end = start + 32 + for group in range(k // scale_factor): + start = group * scale_factor + end = start + scale_factor lhs = (a_payload[:, start:end] * a_scale_payload[:, group:group + 1]) & compute_mask rhs = (b_payload[start:end, :] * b_scale_payload[:, group][None, :]) & compute_mask lhs = _signed_cast_payload_u64(lhs, 16, 32) @@ -1455,11 +1539,110 @@ def scale_payload_matrix(raw_scale: np.ndarray) -> np.ndarray: return _unmix_payload_u32_to_f32_bits_i32(out.astype(np.uint32)) -def test_dot_fma(device, fresh_knobs): +_DOT_FLOAT_DTYPES = [ + ("f32", "f32", "f32"), + ("bf16", "bf16", "f32"), + ("f16", "f16", "f16"), + ("f16", "f16", "f32"), + *[(type_a, type_b, acc_type) + for type_a, type_b, acc_type in itertools.product(("e4m3", "e5m2"), ("e4m3", "e5m2"), ("f16", "f32"))], +] + +_DOT_FMA_DTYPES = [ + *_DOT_FLOAT_DTYPES, + ("f64", "f64", "f64"), +] + +_TCGEN05_FLOAT_DTYPES = [ + *_DOT_FLOAT_DTYPES, + ("f16", "bf16", "f32"), + ("bf16", "f16", "f32"), +] + +_TCGEN05_SCALED_DTYPES = list(itertools.product(("e2m1", "e4m3", "e5m2"), repeat=2)) + +_MFMA_FP8_DTYPES = ("e4m3fnuz", "e5m2fnuz") if is_hip_cdna3() else ("e4m3", "e5m2") + +_MFMA_DOT_CASES = [ + pytest.param("f32", "f32", "f32", 16, 16, 32, 32, 32, 8 if is_hip_cdna3() else 16, 4 if is_hip_cdna3() else 8, + id="f32-f32-f32-broad"), + pytest.param("f64", "f64", "f64", 16, 16, 4, 16, 16, 4, 1, id="f64-f64-f64-minimum"), + pytest.param("f32", "f32", "f32", 16, 16, 4, 16, 16, 4, 1, id="f32-f32-f32-minimum"), + pytest.param("f16", "f16", "f32", 16, 16, 16, 16, 16, 16, 4, id="f16-f16-f32-minimum"), + pytest.param("bf16", "bf16", "f32", 16, 16, 16, 16, 16, 16, 4, id="bf16-bf16-f32-minimum"), + *[ + pytest.param(type_a, type_b, "f32", 16, 16, 32, 16, 16, 32, 8, id=f"{type_a}-{type_b}-f32-minimum") + for type_a, type_b in itertools.product(_MFMA_FP8_DTYPES, repeat=2) + ], +] + +_WMMA_DOT_CASES = [ + pytest.param("f32", "f32", "f32", 32, 32, 32, 4, 2, id="f32-f32-f32-broad"), + pytest.param("f32", "f32", "f32", 16, 16, 4, 4, 2, id="f32-f32-f32-minimum"), + pytest.param("f16", "f16", "f32", 16, 16, 32, 32, 8, id="f16-f16-f32-minimum"), + pytest.param("bf16", "bf16", "f32", 16, 16, 32, 32, 8, id="bf16-bf16-f32-minimum"), + *[ + pytest.param(type_a, type_b, "f32", 16, 16, 64, 64, 8, id=f"{type_a}-{type_b}-f32-minimum") + for type_a, type_b in itertools.product(("e4m3", "e5m2"), repeat=2) + ], +] + + +def _native_mma_k(type_a: str) -> int: + return 256 // _float_dtype_info(type_a)[0] + + +_DOT_FMA_CASES = [ + *[pytest.param(*dtypes, 32, 32, 32, id=f"{'-'.join(dtypes)}-broad") for dtypes in _DOT_FMA_DTYPES], + *[ + pytest.param(*dtypes, 1, 1, _native_mma_k(dtypes[0]), id=f"{'-'.join(dtypes)}-minimum") + for dtypes in _DOT_FMA_DTYPES + ], +] + +_MMA_V2_CASES = [ + pytest.param(*dtypes, 8 if dtypes[0] == "f64" else 16, 8, _native_mma_k(dtypes[0]), 8 if dtypes[0] == "f64" else 16, + id="-".join(dtypes)) for dtypes in _DOT_FMA_DTYPES +] + +_WARP_GROUP_MMA_CASES = [ + *[pytest.param(*dtypes, 64, 64, 64, 32, id=f"{'-'.join(dtypes)}-broad") for dtypes in _DOT_FLOAT_DTYPES], + *[ + pytest.param(*dtypes, 64, 8, _native_mma_k(dtypes[0]), 8, id=f"{'-'.join(dtypes)}-minimum") + for dtypes in _DOT_FLOAT_DTYPES + ], +] + +_TCGEN05_MMA_CASES = [ + *[pytest.param(*dtypes, 64, 64, 64, id=f"{'-'.join(dtypes)}-broad") for dtypes in _TCGEN05_FLOAT_DTYPES], + *[ + pytest.param(*dtypes, 64, 8, _native_mma_k(dtypes[0]), id=f"{'-'.join(dtypes)}-minimum") + for dtypes in _TCGEN05_FLOAT_DTYPES + ], +] + +_TCGEN05_MMA_SCALED_CASES = [ + *[ + pytest.param(type_a, type_b, 128, 128, 128, 32, "e8m0", id=f"{type_a}-{type_b}-broad") + for type_a, type_b in _TCGEN05_SCALED_DTYPES + ], + *[ + pytest.param(type_a, type_b, 128, 128, 64 if type_a == type_b == "e2m1" else 32, 32, "e8m0", + id=f"{type_a}-{type_b}-mxfp-minimum") for type_a, type_b in _TCGEN05_SCALED_DTYPES + ], + pytest.param("e2m1", "e2m1", 128, 128, 64, 16, "e4m3", id="e2m1-e2m1-nvfp4-minimum"), +] + + +@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k"), _DOT_FMA_CASES) +def test_dot_fma(device, type_a, type_b, acc_type, m, n, k, fresh_knobs): _require_cuda_backend(device) + if is_cuda() and torch.cuda.get_device_capability()[0] < 9 and "e4m3" in (type_a, type_b): + pytest.skip("E4M3 requires Hopper or newer") - B = 16 - BLOCK = gl.constexpr(B) + M = gl.constexpr(m) + N = gl.constexpr(n) + K = gl.constexpr(k) fresh_knobs.compilation.instrumentation_mode = "fpsan" @@ -1469,15 +1652,13 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, THREADS_PER_WARP: gl.constexpr): lhs_layout: gl.constexpr = gl.DotOperandLayout(parent=layout, operand_index=0, k_width=0) rhs_layout: gl.constexpr = gl.DotOperandLayout(parent=layout, operand_index=1, k_width=0) - offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] - offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] - # Important: build separate offsets for A and B. - # dot_fma expects operands to represent A[M,K] and B[K,N]. Using the same - # linearized (m,n) offsets for both makes B effectively transposed. - offs_k = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] - a_offs = offs_m * BLOCK + offs_k - b_offs = offs_n * BLOCK + offs_m # load B^T so dot_fma produces A @ B - out_offs = offs_m * BLOCK + offs_n + offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, N, layout=gl.SliceLayout(0, layout))[None, :] + offs_k_row = gl.arange(0, K, layout=gl.SliceLayout(1, layout))[:, None] + offs_k_col = gl.arange(0, K, layout=gl.SliceLayout(0, layout))[None, :] + a_offs = offs_m * K + offs_k_col + b_offs = offs_n * K + offs_k_row + out_offs = offs_m * N + offs_n a = gl.convert_layout(gl.load(a_ptr + a_offs), lhs_layout) b = gl.convert_layout(gl.load(b_ptr + b_offs), rhs_layout) @@ -1486,21 +1667,15 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, THREADS_PER_WARP: gl.constexpr): gl.store(out_ptr + out_offs, out) rs = np.random.RandomState(0) - a_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - b_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - c_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - exp_bits = _mm_payload_u32(a_bits, b_bits.T, c_bits) - - a = torch.tensor(a_bits, device="cuda", dtype=torch.int32) - b = torch.tensor(b_bits, device="cuda", dtype=torch.int32) - c = torch.tensor(c_bits, device="cuda", dtype=torch.int32) - out = torch.empty((B, B), device="cuda", dtype=torch.int32) + a_bits = _random_float_bits(rs, (m, k), type_a) + b_bits = _random_float_bits(rs, (n, k), type_b) + c_bits = _random_float_bits(rs, (m, n), acc_type) + exp_bits = _mm_payload_bits(a_bits, b_bits.T, c_bits, type_a, type_b, acc_type) - # Wrap int storage as fp32 so fpsan operates on payload bits. - aw = triton.TensorWrapper(a, dtype=torch.float32) - bw = triton.TensorWrapper(b, dtype=torch.float32) - cw = triton.TensorWrapper(c, dtype=torch.float32) - outw = triton.TensorWrapper(out, dtype=torch.float32) + _, aw = _as_float_bits_tensor(a_bits, type_a) + _, bw = _as_float_bits_tensor(b_bits, type_b) + _, cw = _as_float_bits_tensor(c_bits, acc_type) + out, outw = _as_float_bits_tensor(np.empty((m, n), dtype=_float_dtype_info(acc_type)[2]), acc_type) compiled = kernel[(1, )](aw, bw, cw, outw, THREADS_PER_WARP=THREADS_PER_WARP) ttgir = compiled.asm["ttgir"] @@ -1510,6 +1685,62 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, THREADS_PER_WARP: gl.constexpr): _assert_payload_equal(out, exp_bits) +@pytest.mark.skipif(not is_cuda(), reason="Requires NVIDIA MMA v2") +@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k", "instr_m"), _MMA_V2_CASES) +def test_mma_v2(device, type_a, type_b, acc_type, m, n, k, instr_m, fresh_knobs): + _require_cuda_backend(device) + if torch.cuda.get_device_capability()[0] < 9 and "e4m3" in (type_a, type_b): + pytest.skip("E4M3 requires Hopper or newer") + + M = gl.constexpr(m) + N = gl.constexpr(n) + K = gl.constexpr(k) + + fresh_knobs.compilation.instrumentation_mode = "fpsan" + + @gluon.jit + def kernel(a_ptr, b_ptr, c_ptr, out_ptr, A_K_WIDTH: gl.constexpr, B_K_WIDTH: gl.constexpr, INSTR_M: gl.constexpr, + PRECISION: gl.constexpr, THREADS_PER_WARP: gl.constexpr): + layout: gl.constexpr = gl.BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [4, 1], [1, 0]) + acc_layout: gl.constexpr = gl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], + instr_shape=[INSTR_M, 8]) + lhs_layout: gl.constexpr = gl.DotOperandLayout(parent=acc_layout, operand_index=0, k_width=A_K_WIDTH) + rhs_layout: gl.constexpr = gl.DotOperandLayout(parent=acc_layout, operand_index=1, k_width=B_K_WIDTH) + + offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, N, layout=gl.SliceLayout(0, layout))[None, :] + offs_k_row = gl.arange(0, K, layout=gl.SliceLayout(1, layout))[:, None] + offs_k_col = gl.arange(0, K, layout=gl.SliceLayout(0, layout))[None, :] + a_offs = offs_m * K + offs_k_col + b_offs = offs_k_row * N + offs_n + out_offs = offs_m * N + offs_n + + a = gl.convert_layout(gl.load(a_ptr + a_offs), lhs_layout) + b = gl.convert_layout(gl.load(b_ptr + b_offs), rhs_layout) + c = gl.convert_layout(gl.load(c_ptr + out_offs), acc_layout) + out = mma_v2(a, b, c, input_precision=PRECISION) + gl.store(out_ptr + out_offs, gl.convert_layout(out, layout)) + + rs = np.random.RandomState(0) + a_bits = _random_float_bits(rs, (m, k), type_a) + b_bits = _random_float_bits(rs, (k, n), type_b) + c_bits = _random_float_bits(rs, (m, n), acc_type) + exp_bits = _mm_payload_bits(a_bits, b_bits, c_bits, type_a, type_b, acc_type) + + _, aw = _as_float_bits_tensor(a_bits, type_a) + _, bw = _as_float_bits_tensor(b_bits, type_b) + _, cw = _as_float_bits_tensor(c_bits, acc_type) + out, outw = _as_float_bits_tensor(np.empty((m, n), dtype=_float_dtype_info(acc_type)[2]), acc_type) + + a_width = _float_dtype_info(type_a)[0] + b_width = _float_dtype_info(type_b)[0] + precision = "tf32" if type_a == "f32" else "ieee" + kernel[(1, )](aw, bw, cw, outw, A_K_WIDTH=max(32 // a_width, 1), B_K_WIDTH=max(32 // b_width, 1), INSTR_M=instr_m, + PRECISION=precision, THREADS_PER_WARP=THREADS_PER_WARP) + + _assert_payload_equal(out, exp_bits) + + def test_dot_fma_batched(device, fresh_knobs): _require_cuda_backend(device) @@ -1568,64 +1799,67 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, THREADS_PER_WARP: gl.constexpr): @pytest.mark.skipif(not is_hopper(), reason="Requires Hopper") @pytest.mark.parametrize(("use_acc", "is_async"), [(False, False), (True, False), (True, True)]) -def test_warpgroup_mma(device, use_acc, is_async, fresh_knobs): +@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k", "instr_n"), _WARP_GROUP_MMA_CASES) +def test_warpgroup_mma(device, use_acc, is_async, type_a, type_b, acc_type, m, n, k, instr_n, fresh_knobs): _require_cuda_backend(device) - B = 64 - BLOCK = gl.constexpr(B) + M = gl.constexpr(m) + N = gl.constexpr(n) + K = gl.constexpr(k) fresh_knobs.compilation.instrumentation_mode = "fpsan" @gluon.jit - def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr, IS_ASYNC: gl.constexpr): + def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr, IS_ASYNC: gl.constexpr, A_DTYPE: gl.constexpr, + B_DTYPE: gl.constexpr, INSTR_N: gl.constexpr, INSTR_K: gl.constexpr, PRECISION: gl.constexpr): layout: gl.constexpr = gl.BlockedLayout([1, 1], [32, 1], [gl.num_warps(), 1], [1, 0]) acc_layout: gl.constexpr = gl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], - instr_shape=[16, 32, 16]) + instr_shape=[16, INSTR_N, INSTR_K]) - offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] - offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] - offs_k_row = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] - offs_k_col = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] + offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, N, layout=gl.SliceLayout(0, layout))[None, :] + offs_k_row = gl.arange(0, K, layout=gl.SliceLayout(1, layout))[:, None] + offs_k_col = gl.arange(0, K, layout=gl.SliceLayout(0, layout))[None, :] - a_tile = gl.load(a_ptr + offs_m * BLOCK + offs_k_col) - b_tile = gl.load(b_ptr + offs_k_row * BLOCK + offs_n) - c_tile = gl.load(c_ptr + offs_m * BLOCK + offs_n) + a_tile = gl.load(a_ptr + offs_m * K + offs_k_col) + b_tile = gl.load(b_ptr + offs_k_row * N + offs_n) + c_tile = gl.load(c_ptr + offs_m * N + offs_n) - smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK, BLOCK], gl.float32) - smem_a = gl.allocate_shared_memory(gl.float32, [BLOCK, BLOCK], smem_layout, a_tile) - smem_b = gl.allocate_shared_memory(gl.float32, [BLOCK, BLOCK], smem_layout, b_tile) + smem_layout_a: gl.constexpr = gl.NVMMASharedLayout.get_default_for([M, K], A_DTYPE) + smem_layout_b: gl.constexpr = gl.NVMMASharedLayout.get_default_for([K, N], B_DTYPE) + smem_a = gl.allocate_shared_memory(A_DTYPE, [M, K], smem_layout_a, a_tile) + smem_b = gl.allocate_shared_memory(B_DTYPE, [K, N], smem_layout_b, b_tile) acc = gl.convert_layout(c_tile, acc_layout) - acc = hopper.warpgroup_mma(smem_a, smem_b, acc, use_acc=USE_ACC, precision="tf32", is_async=IS_ASYNC) + acc = hopper.warpgroup_mma(smem_a, smem_b, acc, use_acc=USE_ACC, precision=PRECISION, is_async=IS_ASYNC) if IS_ASYNC: acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc]) out = gl.convert_layout(acc, layout) - gl.store(out_ptr + offs_m * BLOCK + offs_n, out) + gl.store(out_ptr + offs_m * N + offs_n, out) rs = np.random.RandomState(0) - a_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - b_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - c_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - exp_bits = _mm_payload_u32(a_bits, b_bits, c_bits if use_acc else None) - - a = torch.tensor(a_bits, device="cuda", dtype=torch.int32) - b = torch.tensor(b_bits, device="cuda", dtype=torch.int32) - c = torch.tensor(c_bits, device="cuda", dtype=torch.int32) - out = torch.empty((B, B), device="cuda", dtype=torch.int32) - - aw = triton.TensorWrapper(a, dtype=torch.float32) - bw = triton.TensorWrapper(b, dtype=torch.float32) - cw = triton.TensorWrapper(c, dtype=torch.float32) - outw = triton.TensorWrapper(out, dtype=torch.float32) - - kernel[(1, )](aw, bw, cw, outw, USE_ACC=use_acc, IS_ASYNC=is_async) + a_bits = _random_float_bits(rs, (m, k), type_a) + b_bits = _random_float_bits(rs, (k, n), type_b) + c_bits = _random_float_bits(rs, (m, n), acc_type) + exp_bits = _mm_payload_bits(a_bits, b_bits, c_bits if use_acc else None, type_a, type_b, acc_type) + + _, aw = _as_float_bits_tensor(a_bits, type_a) + _, bw = _as_float_bits_tensor(b_bits, type_b) + _, cw = _as_float_bits_tensor(c_bits, acc_type) + out, outw = _as_float_bits_tensor(np.empty((m, n), dtype=_float_dtype_info(acc_type)[2]), acc_type) + + a_width, _, _, _, _, a_dtype = _float_dtype_info(type_a) + _, _, _, _, _, b_dtype = _float_dtype_info(type_b) + precision = "tf32" if type_a == "f32" else "ieee" + kernel[(1, )](aw, bw, cw, outw, USE_ACC=use_acc, IS_ASYNC=is_async, A_DTYPE=a_dtype, B_DTYPE=b_dtype, + INSTR_N=instr_n, INSTR_K=256 // a_width, PRECISION=precision) _assert_payload_equal(out, exp_bits) @pytest.mark.skipif(not (is_hip_cdna4() or is_hip_gfx1250()), reason="Requires DotScaledOp support (CDNA4, or GFX1250)") -@pytest.mark.parametrize("type_a", ["e2m1", "e4m3", "e5m2"]) -@pytest.mark.parametrize("type_b", ["e2m1", "e4m3", "e5m2", "bf16"]) +@pytest.mark.parametrize("type_a", ["e2m1", "e4m3", "e5m2", "bf16", "fp16"]) +@pytest.mark.parametrize("type_b", ["e2m1", "e4m3", "e5m2", "bf16", "fp16"]) def test_dot_scaled(device, type_a, type_b, fresh_knobs): _require_cuda_backend(device) @@ -1680,12 +1914,18 @@ def kernel(a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, out_ptr, BLOCK_M: tl.constexp a_scale = torch.tensor(a_scale_bits, device="cuda", dtype=torch.uint8) b_scale = torch.tensor(b_scale_bits, device="cuda", dtype=torch.uint8) - if type_b == "bf16": + if type_a in ("bf16", "fp16"): + a_bits = rs.randint(0, 65536, size=(B, packed_k_a)).astype(np.uint16) + a = torch.tensor(a_bits, device="cuda", + dtype=torch.uint16).view(torch.bfloat16 if type_a == "bf16" else torch.float16) + if type_b in ("bf16", "fp16"): b_bits = rs.randint(0, 65536, size=(packed_k_b, B)).astype(np.uint16) - b = torch.tensor(b_bits, device="cuda", dtype=torch.uint16).view(torch.bfloat16) + b = torch.tensor(b_bits, device="cuda", + dtype=torch.uint16).view(torch.bfloat16 if type_b == "bf16" else torch.float16) - exp_bits = _dot_scaled_payload_u32(a_bits, b_bits, a_scale_bits, None if type_b == "bf16" else b_scale_bits, a_pack, - b_pack, type_a, type_b) + exp_bits = _dot_scaled_payload_u32(a_bits, b_bits, None if type_a in ("bf16", "fp16") else a_scale_bits, + None if type_b in ("bf16", "fp16") else b_scale_bits, a_pack, b_pack, type_a, + type_b) out = torch.empty((B, B), device="cuda", dtype=torch.int32) outw = triton.TensorWrapper(out, dtype=torch.float32) @@ -1697,39 +1937,42 @@ def kernel(a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, out_ptr, BLOCK_M: tl.constexp @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") @pytest.mark.parametrize("use_acc", [False, True]) -def test_tcgen05_mma(device, use_acc, fresh_knobs): +@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k"), _TCGEN05_MMA_CASES) +def test_tcgen05_mma(device, use_acc, type_a, type_b, acc_type, m, n, k, fresh_knobs): _require_cuda_backend(device) - B = 64 - BLOCK = gl.constexpr(B) + M = gl.constexpr(m) + N = gl.constexpr(n) + K = gl.constexpr(k) fresh_knobs.compilation.instrumentation_mode = "fpsan" @gluon.jit - def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr): + def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr, A_DTYPE: gl.constexpr, B_DTYPE: gl.constexpr, + ACC_DTYPE: gl.constexpr, ACC_BITWIDTH: gl.constexpr): layout: gl.constexpr = gl.BlockedLayout([1, 1], [32, 1], [gl.num_warps(), 1], [1, 0]) - offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] - offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] - offs_k_row = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] - offs_k_col = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] + offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, N, layout=gl.SliceLayout(0, layout))[None, :] + offs_n_row = gl.arange(0, N, layout=gl.SliceLayout(1, layout))[:, None] + offs_k_col = gl.arange(0, K, layout=gl.SliceLayout(0, layout))[None, :] - a_offs = offs_m * BLOCK + offs_k_col - b_offs = offs_k_row * BLOCK + offs_n - out_offs = offs_m * BLOCK + offs_n + a_offs = offs_m * K + offs_k_col + b_offs = offs_n_row * K + offs_k_col + out_offs = offs_m * N + offs_n a_tile = gl.load(a_ptr + a_offs) b_tile = gl.load(b_ptr + b_offs) - smem_layout_a: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK, BLOCK], gl.float32) - smem_layout_b: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK, BLOCK], gl.float32) - smem_a = gl.allocate_shared_memory(gl.float32, [BLOCK, BLOCK], smem_layout_a) - smem_b = gl.allocate_shared_memory(gl.float32, [BLOCK, BLOCK], smem_layout_b) + smem_layout_a: gl.constexpr = gl.NVMMASharedLayout.get_default_for([M, K], A_DTYPE) + smem_layout_b: gl.constexpr = gl.NVMMASharedLayout.get_default_for([N, K], B_DTYPE) + smem_a = gl.allocate_shared_memory(A_DTYPE, [M, K], smem_layout_a) + smem_b = gl.allocate_shared_memory(B_DTYPE, [N, K], smem_layout_b) smem_a.store(a_tile) smem_b.store(b_tile) - tmem_layout: gl.constexpr = TensorMemoryLayout((BLOCK, BLOCK), col_stride=1) - acc_tmem = allocate_tensor_memory(gl.float32, [BLOCK, BLOCK], layout=tmem_layout) + tmem_layout: gl.constexpr = TensorMemoryLayout((M, N), col_stride=32 // ACC_BITWIDTH) + acc_tmem = allocate_tensor_memory(ACC_DTYPE, [M, N], layout=tmem_layout) acc_reg_layout: gl.constexpr = acc_tmem.get_reg_layout() if USE_ACC: c_tile = gl.load(c_ptr + out_offs) @@ -1750,22 +1993,21 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr): gl.store(out_ptr + out_offs, out) rs = np.random.RandomState(0) - a_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - b_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - c_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - exp_bits = _mm_payload_u32(a_bits, b_bits.T, c_bits if use_acc else None) - - a = torch.tensor(a_bits, device="cuda", dtype=torch.int32) - b = torch.tensor(b_bits, device="cuda", dtype=torch.int32) - c = torch.tensor(c_bits, device="cuda", dtype=torch.int32) - out = torch.empty((B, B), device="cuda", dtype=torch.int32) - - aw = triton.TensorWrapper(a, dtype=torch.float32) - bw = triton.TensorWrapper(b, dtype=torch.float32) - cw = triton.TensorWrapper(c, dtype=torch.float32) - outw = triton.TensorWrapper(out, dtype=torch.float32) - - kernel[(1, )](aw, bw, cw, outw, USE_ACC=use_acc) + a_bits = _random_float_bits(rs, (m, k), type_a) + b_bits = _random_float_bits(rs, (n, k), type_b) + c_bits = _random_float_bits(rs, (m, n), acc_type) + exp_bits = _mm_payload_bits(a_bits, b_bits.T, c_bits if use_acc else None, type_a, type_b, acc_type) + + _, aw = _as_float_bits_tensor(a_bits, type_a) + _, bw = _as_float_bits_tensor(b_bits, type_b) + _, cw = _as_float_bits_tensor(c_bits, acc_type) + out, outw = _as_float_bits_tensor(np.empty((m, n), dtype=_float_dtype_info(acc_type)[2]), acc_type) + + a_dtype = _float_dtype_info(type_a)[5] + b_dtype = _float_dtype_info(type_b)[5] + acc_bitwidth, _, _, _, _, acc_dtype = _float_dtype_info(acc_type) + kernel[(1, )](aw, bw, cw, outw, USE_ACC=use_acc, A_DTYPE=a_dtype, B_DTYPE=b_dtype, ACC_DTYPE=acc_dtype, + ACC_BITWIDTH=acc_bitwidth) _assert_payload_equal(out, exp_bits) @@ -1914,103 +2156,126 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, USE_ACC: gl.constexpr): @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") -@pytest.mark.parametrize("elem_type", ["e2m1", "e4m3", "e5m2"]) -def test_tcgen05_mma_scaled(device, elem_type, fresh_knobs): +@pytest.mark.parametrize(("type_a", "type_b", "m", "n", "k", "scale_factor", "scale_type"), _TCGEN05_MMA_SCALED_CASES) +def test_tcgen05_mma_scaled(device, type_a, type_b, m, n, k, scale_factor, scale_type, fresh_knobs): _require_cuda_backend(device) - B = 128 - BLOCK = gl.constexpr(B) - SCALE_K = gl.constexpr(B // 32) + M = gl.constexpr(m) + N = gl.constexpr(n) + K = gl.constexpr(k) + SCALE_K = gl.constexpr(k // scale_factor) fresh_knobs.compilation.instrumentation_mode = "fpsan" @gluon.jit - def kernel(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, c_ptr, out_ptr, TYPE: gl.constexpr): + def kernel(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, c_ptr, out_ptr, TYPE_A: gl.constexpr, TYPE_B: gl.constexpr, + SCALE_DTYPE: gl.constexpr): layout: gl.constexpr = gl.BlockedLayout([1, 1], [32, 1], [gl.num_warps(), 1], [1, 0]) - IS_FP4: gl.constexpr = TYPE == "e2m1" - PACK_FACTOR: gl.constexpr = 2 if IS_FP4 else 1 - PACKED_K: gl.constexpr = BLOCK // PACK_FACTOR - ELEM_DTYPE: gl.constexpr = gl.uint8 if IS_FP4 else (gl.float8e4nv if TYPE == "e4m3" else gl.float8e5) - a_nvmma_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK, PACKED_K], ELEM_DTYPE) - b_nvmma_layout: gl.constexpr = (gl.NVMMASharedLayout.get_default_for([BLOCK, PACKED_K], ELEM_DTYPE) - if IS_FP4 else gl.NVMMASharedLayout(swizzle_byte_width=128, transposed=False, - element_bitwidth=8, rank=2)) + IS_FP4_A: gl.constexpr = TYPE_A == "e2m1" + IS_FP4_B: gl.constexpr = TYPE_B == "e2m1" + PACK_FACTOR_A: gl.constexpr = 2 if IS_FP4_A else 1 + PACK_FACTOR_B: gl.constexpr = 2 if IS_FP4_B else 1 + PACKED_K_A: gl.constexpr = K // PACK_FACTOR_A + PACKED_K_B: gl.constexpr = K // PACK_FACTOR_B + ELEM_DTYPE_A: gl.constexpr = gl.uint8 if IS_FP4_A else (gl.float8e4nv if TYPE_A == "e4m3" else gl.float8e5) + ELEM_DTYPE_B: gl.constexpr = gl.uint8 if IS_FP4_B else (gl.float8e4nv if TYPE_B == "e4m3" else gl.float8e5) + a_nvmma_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([M, PACKED_K_A], ELEM_DTYPE_A) + b_nvmma_layout: gl.constexpr = (gl.NVMMASharedLayout.get_default_for([N, PACKED_K_B], ELEM_DTYPE_B) + if IS_FP4_B else gl.NVMMASharedLayout(swizzle_byte_width=128, transposed=False, + element_bitwidth=8, rank=2)) scale_layout: gl.constexpr = TensorMemoryScalesLayout() - offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] - offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] - offs_k_row = gl.arange(0, PACKED_K, layout=gl.SliceLayout(1, layout))[:, None] - offs_k_col = gl.arange(0, PACKED_K, layout=gl.SliceLayout(0, layout))[None, :] - - a_tile = gl.load(a_ptr + offs_m * PACKED_K + offs_k_col) - c_tile = gl.load(c_ptr + offs_m * BLOCK + offs_n) - a_smem = gl.allocate_shared_memory(ELEM_DTYPE, [BLOCK, PACKED_K], a_nvmma_layout, a_tile) - if IS_FP4: - b_tile = gl.load(b_ptr + offs_m * PACKED_K + offs_k_col) - b_smem = gl.allocate_shared_memory(ELEM_DTYPE, [BLOCK, PACKED_K], b_nvmma_layout, b_tile) + offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, N, layout=gl.SliceLayout(0, layout))[None, :] + offs_n_row = gl.arange(0, N, layout=gl.SliceLayout(1, layout))[:, None] + offs_bk_row = gl.arange(0, PACKED_K_B, layout=gl.SliceLayout(1, layout))[:, None] + offs_ak_col = gl.arange(0, PACKED_K_A, layout=gl.SliceLayout(0, layout))[None, :] + offs_bk_col = gl.arange(0, PACKED_K_B, layout=gl.SliceLayout(0, layout))[None, :] + + a_tile = gl.load(a_ptr + offs_m * PACKED_K_A + offs_ak_col) + c_tile = gl.load(c_ptr + offs_m * N + offs_n) + a_smem = gl.allocate_shared_memory(ELEM_DTYPE_A, [M, PACKED_K_A], a_nvmma_layout, a_tile) + if IS_FP4_B: + b_tile = gl.load(b_ptr + offs_n_row * PACKED_K_B + offs_bk_col) + b_smem = gl.allocate_shared_memory(ELEM_DTYPE_B, [N, PACKED_K_B], b_nvmma_layout, b_tile) b_mma = b_smem.permute((1, 0)) else: - b_tile = gl.load(b_ptr + offs_k_row * BLOCK + offs_n) - b_smem = gl.allocate_shared_memory(ELEM_DTYPE, [PACKED_K, BLOCK], b_nvmma_layout, b_tile) + b_tile = gl.load(b_ptr + offs_bk_row * N + offs_n) + b_smem = gl.allocate_shared_memory(ELEM_DTYPE_B, [PACKED_K_B, N], b_nvmma_layout, b_tile) b_mma = b_smem - tmem_layout: gl.constexpr = TensorMemoryLayout((BLOCK, BLOCK), col_stride=1) - acc_tmem = allocate_tensor_memory(gl.float32, [BLOCK, BLOCK], layout=tmem_layout) + tmem_layout: gl.constexpr = TensorMemoryLayout((M, N), col_stride=1) + acc_tmem = allocate_tensor_memory(gl.float32, [M, N], layout=tmem_layout) acc_tmem.store(gl.convert_layout(c_tile, acc_tmem.get_reg_layout())) - a_scale_tmem = allocate_tensor_memory(gl.int8, [BLOCK, SCALE_K], layout=scale_layout) - b_scale_tmem = allocate_tensor_memory(gl.int8, [BLOCK, SCALE_K], layout=scale_layout) + a_scale_tmem = allocate_tensor_memory(SCALE_DTYPE, [M, SCALE_K], layout=scale_layout) + b_scale_tmem = allocate_tensor_memory(SCALE_DTYPE, [N, SCALE_K], layout=scale_layout) a_scale_reg_layout: gl.constexpr = a_scale_tmem.get_reg_layout() b_scale_reg_layout: gl.constexpr = b_scale_tmem.get_reg_layout() scale_offs_k = gl.arange(0, SCALE_K, layout=gl.SliceLayout(0, a_scale_reg_layout))[None, :] - scale_offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, a_scale_reg_layout))[:, None] - scale_offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, b_scale_reg_layout))[:, None] + scale_offs_m = gl.arange(0, M, layout=gl.SliceLayout(1, a_scale_reg_layout))[:, None] + scale_offs_n = gl.arange(0, N, layout=gl.SliceLayout(1, b_scale_reg_layout))[:, None] a_scale_tmem.store(gl.load(a_scale_ptr + scale_offs_m * SCALE_K + scale_offs_k)) b_scale_tmem.store(gl.load(b_scale_ptr + scale_offs_n * SCALE_K + scale_offs_k)) bar = gl.allocate_shared_memory(gl.int64, [1], gl.constexpr(mbarrier.MBarrierLayout())) mbarrier.init(bar, count=1) - tcgen05_mma_scaled(a_smem, b_mma, acc_tmem, a_scale_tmem, b_scale_tmem, TYPE, TYPE, use_acc=True, + tcgen05_mma_scaled(a_smem, b_mma, acc_tmem, a_scale_tmem, b_scale_tmem, TYPE_A, TYPE_B, use_acc=True, mbarriers=[bar]) mbarrier.wait(bar, phase=0) mbarrier.invalidate(bar) out = gl.convert_layout(acc_tmem.load(), layout) - gl.store(out_ptr + offs_m * BLOCK + offs_n, out) + gl.store(out_ptr + offs_m * N + offs_n, out) rs = np.random.RandomState(0) - pack_factor = 2 if elem_type == "e2m1" else 1 - packed_k = B // pack_factor - a_bits = rs.randint(0 if elem_type == "e2m1" else 20, 256 if elem_type == "e2m1" else 40, size=(B, packed_k), + pack_factor_a = 2 if type_a == "e2m1" else 1 + pack_factor_b = 2 if type_b == "e2m1" else 1 + packed_k_a = k // pack_factor_a + packed_k_b = k // pack_factor_b + a_bits = rs.randint(0 if type_a == "e2m1" else 20, 256 if type_a == "e2m1" else 40, size=(m, packed_k_a), dtype=np.uint8) - if elem_type == "e2m1": - b_bits = rs.randint(0, 256, size=(B, packed_k), dtype=np.uint8) + if type_b == "e2m1": + b_bits = rs.randint(0, 256, size=(n, packed_k_b), dtype=np.uint8) b_ref_bits = b_bits.T else: - b_bits = rs.randint(20, 40, size=(packed_k, B), dtype=np.uint8) + b_bits = rs.randint(20, 40, size=(packed_k_b, n), dtype=np.uint8) b_ref_bits = b_bits - a_scale_bits = rs.randint(1, 4, size=(B, B // 32), dtype=np.int8) - b_scale_bits = rs.randint(1, 4, size=(B, B // 32), dtype=np.int8) - c_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) + a_scale_bits = rs.randint(1, 4, size=(m, k // scale_factor), dtype=np.int8) + b_scale_bits = rs.randint(1, 4, size=(n, k // scale_factor), dtype=np.int8) + if scale_type == "e4m3": + a_scale_bits = rs.randint(1, 0x40, size=(m, k // scale_factor), dtype=np.uint8) + b_scale_bits = rs.randint(1, 0x40, size=(n, k // scale_factor), dtype=np.uint8) + c_bits = rs.randint(-(2**31), 2**31 - 1, size=(m, n), dtype=np.int32) exp_bits = _mm_scaled_payload_u32(a_bits, b_ref_bits, a_scale_bits.view(np.uint8), b_scale_bits.view(np.uint8), - c_bits, a_pack=pack_factor, b_pack=pack_factor, elem_type=elem_type) + c_bits, a_pack=pack_factor_a, b_pack=pack_factor_b, type_a=type_a, type_b=type_b, + scale_factor=scale_factor, scale_type=scale_type) - if elem_type == "e2m1": + if type_a == "e2m1": a = torch.tensor(a_bits, device="cuda", dtype=torch.uint8) + else: + torch_dtype_a = torch.float8_e4m3fn if type_a == "e4m3" else torch.float8_e5m2 + a = torch.tensor(a_bits, device="cuda", dtype=torch.uint8).view(torch_dtype_a) + if type_b == "e2m1": b = torch.tensor(b_bits, device="cuda", dtype=torch.uint8) else: - torch_dtype = torch.float8_e4m3fn if elem_type == "e4m3" else torch.float8_e5m2 - a = torch.tensor(a_bits, device="cuda", dtype=torch.uint8).view(torch_dtype) - b = torch.tensor(b_bits, device="cuda", dtype=torch.uint8).view(torch_dtype) - a_scale = torch.tensor(a_scale_bits, device="cuda", dtype=torch.int8) - b_scale = torch.tensor(b_scale_bits, device="cuda", dtype=torch.int8) + torch_dtype_b = torch.float8_e4m3fn if type_b == "e4m3" else torch.float8_e5m2 + b = torch.tensor(b_bits, device="cuda", dtype=torch.uint8).view(torch_dtype_b) + if scale_type == "e4m3": + a_scale = torch.tensor(a_scale_bits, device="cuda", dtype=torch.uint8).view(torch.float8_e4m3fn) + b_scale = torch.tensor(b_scale_bits, device="cuda", dtype=torch.uint8).view(torch.float8_e4m3fn) + scale_dtype = gl.float8e4nv + else: + a_scale = torch.tensor(a_scale_bits, device="cuda", dtype=torch.int8) + b_scale = torch.tensor(b_scale_bits, device="cuda", dtype=torch.int8) + scale_dtype = gl.int8 c = torch.tensor(c_bits, device="cuda", dtype=torch.int32) - out = torch.empty((B, B), device="cuda", dtype=torch.int32) + out = torch.empty((m, n), device="cuda", dtype=torch.int32) cw = triton.TensorWrapper(c, dtype=torch.float32) outw = triton.TensorWrapper(out, dtype=torch.float32) - kernel[(1, )](a, b, a_scale, b_scale, cw, outw, TYPE=elem_type) + kernel[(1, )](a, b, a_scale, b_scale, cw, outw, TYPE_A=type_a, TYPE_B=type_b, SCALE_DTYPE=scale_dtype) _assert_payload_equal(out, exp_bits) @@ -2090,7 +2355,7 @@ def kernel(a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, c_ptr, out_ptr): b_scale_bits = rs.randint(1, 4, size=(N, K // 32), dtype=np.int8) c_bits = rs.randint(-(2**31), 2**31 - 1, size=(M, N), dtype=np.int32) exp_bits = _mm_scaled_payload_u32(a_bits, b_bits.T, a_scale_bits.view(np.uint8), b_scale_bits.view(np.uint8), - c_bits, a_pack=1, b_pack=1, elem_type="e5m2") + c_bits, a_pack=1, b_pack=1, type_a="e5m2", type_b="e5m2") a = torch.tensor(a_bits, device="cuda", dtype=torch.uint8).view(torch.float8_e5m2) b = torch.tensor(b_bits, device="cuda", dtype=torch.uint8).view(torch.float8_e5m2) @@ -2360,20 +2625,18 @@ def loop_sum_kernel(x_ptr, out_ptr, N: tl.constexpr): @pytest.mark.skipif(not (is_hip_cdna3() or is_hip_cdna4()), reason="Requires CDNA3 or CDNA4") -def test_mfma_dot(device, fresh_knobs): +@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k", "instr_m", "instr_n", "instr_k", "k_width"), + _MFMA_DOT_CASES) +def test_mfma_dot(device, type_a, type_b, acc_type, m, n, k, instr_m, instr_n, instr_k, k_width, fresh_knobs): _require_cuda_backend(device) - M, N, K = 16, 16, 32 - fresh_knobs.compilation.instrumentation_mode = "fpsan" cdna_version = 3 if is_hip_cdna3() else 4 - nonkdim = 32 - kdim = 8 if cdna_version == 3 else 16 - k_width_val = 4 if cdna_version == 3 else 8 blocked = gl.BlockedLayout([4, 4], [4, 16], [4, 1], [1, 0]) - mfma_layout = gl.amd.AMDMFMALayout(cdna_version, [nonkdim, nonkdim, kdim], True, [4, 1]) + mfma_layout = gl.amd.AMDMFMALayout(cdna_version, [instr_m, instr_n, instr_k], True, [4, 1], + element_bitwidth=_float_dtype_info(acc_type)[0]) @gluon.jit def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, BLOCK_K: gl.constexpr, @@ -2399,47 +2662,43 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK_M: gl.constexpr, BLOCK_N: gl.cons gl.store(out_ptr + offs_am[:, None] * BLOCK_N + offs_bn[None, :], result) rs = np.random.RandomState(0) - a_bits = rs.randint(-(2**31), 2**31 - 1, size=(M, K), dtype=np.int32) - b_bits = rs.randint(-(2**31), 2**31 - 1, size=(K, N), dtype=np.int32) - c_bits = rs.randint(-(2**31), 2**31 - 1, size=(M, N), dtype=np.int32) - exp_bits = _mm_payload_u32(a_bits, b_bits, c_bits) - - a = torch.tensor(a_bits, device="cuda", dtype=torch.int32) - b = torch.tensor(b_bits, device="cuda", dtype=torch.int32) - c = torch.tensor(c_bits, device="cuda", dtype=torch.int32) - out = torch.empty((M, N), device="cuda", dtype=torch.int32) + a_bits = _random_float_bits(rs, (m, k), type_a) + b_bits = _random_float_bits(rs, (k, n), type_b) + c_bits = _random_float_bits(rs, (m, n), acc_type) + exp_bits = _mm_payload_bits(a_bits, b_bits, c_bits, type_a, type_b, acc_type) - aw = triton.TensorWrapper(a, dtype=torch.float32) - bw = triton.TensorWrapper(b, dtype=torch.float32) - cw = triton.TensorWrapper(c, dtype=torch.float32) - outw = triton.TensorWrapper(out, dtype=torch.float32) + _, aw = _as_float_bits_tensor(a_bits, type_a) + _, bw = _as_float_bits_tensor(b_bits, type_b) + _, cw = _as_float_bits_tensor(c_bits, acc_type) + out, outw = _as_float_bits_tensor(np.empty((m, n), dtype=_float_dtype_info(acc_type)[2]), acc_type) - kernel[(1, )](aw, bw, cw, outw, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, blocked=blocked, k_width=k_width_val, + kernel[(1, )](aw, bw, cw, outw, BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, blocked=blocked, k_width=k_width, mfma_layout=mfma_layout) _assert_payload_equal(out, exp_bits) @pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires gfx1250") -def test_wmma_dot(device, fresh_knobs): +@pytest.mark.parametrize(("type_a", "type_b", "acc_type", "m", "n", "k", "instr_k", "k_width"), _WMMA_DOT_CASES) +def test_wmma_dot(device, type_a, type_b, acc_type, m, n, k, instr_k, k_width, fresh_knobs): _require_cuda_backend(device) - B = 32 fresh_knobs.compilation.instrumentation_mode = "fpsan" @gluon.jit - def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: gl.constexpr, INSTR_SHAPE_K: gl.constexpr, K_WIDTH: gl.constexpr): + def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, BLOCK_K: gl.constexpr, + INSTR_SHAPE_K: gl.constexpr, K_WIDTH: gl.constexpr): blocked: gl.constexpr = gl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) wmma: gl.constexpr = gl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, INSTR_SHAPE_K]) - offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, blocked))[:, None] - offs_k = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, blocked))[None, :] - offs_bk = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, blocked))[:, None] - offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, blocked))[None, :] + offs_m = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, blocked))[:, None] + offs_k = gl.arange(0, BLOCK_K, layout=gl.SliceLayout(0, blocked))[None, :] + offs_bk = gl.arange(0, BLOCK_K, layout=gl.SliceLayout(1, blocked))[:, None] + offs_n = gl.arange(0, BLOCK_N, layout=gl.SliceLayout(0, blocked))[None, :] - a = gl.load(a_ptr + offs_m * BLOCK + offs_k) - b = gl.load(b_ptr + offs_bk * BLOCK + offs_n) - c = gl.load(c_ptr + offs_m * BLOCK + offs_n) + a = gl.load(a_ptr + offs_m * BLOCK_K + offs_k) + b = gl.load(b_ptr + offs_bk * BLOCK_N + offs_n) + c = gl.load(c_ptr + offs_m * BLOCK_N + offs_n) c = gl.convert_layout(c, wmma) a = gl.convert_layout(a, gl.DotOperandLayout(0, wmma, K_WIDTH)) @@ -2447,26 +2706,21 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: gl.constexpr, INSTR_SHAPE_K: gl. acc = gl.amd.gfx1250.wmma(a, b, c) out_layout: gl.constexpr = gl.SliceLayout(1, wmma) - offs_cm = gl.arange(0, BLOCK, layout=out_layout)[:, None] - offs_cn = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, wmma))[None, :] - gl.store(out_ptr + offs_cm * BLOCK + offs_cn, acc) + offs_cm = gl.arange(0, BLOCK_M, layout=out_layout)[:, None] + offs_cn = gl.arange(0, BLOCK_N, layout=gl.SliceLayout(0, wmma))[None, :] + gl.store(out_ptr + offs_cm * BLOCK_N + offs_cn, acc) rs = np.random.RandomState(0) - a_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - b_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - c_bits = rs.randint(-(2**31), 2**31 - 1, size=(B, B), dtype=np.int32) - exp_bits = _mm_payload_u32(a_bits, b_bits, c_bits) - - a = torch.tensor(a_bits, device="cuda", dtype=torch.int32) - b = torch.tensor(b_bits, device="cuda", dtype=torch.int32) - c = torch.tensor(c_bits, device="cuda", dtype=torch.int32) - out = torch.empty((B, B), device="cuda", dtype=torch.int32) + a_bits = _random_float_bits(rs, (m, k), type_a) + b_bits = _random_float_bits(rs, (k, n), type_b) + c_bits = _random_float_bits(rs, (m, n), acc_type) + exp_bits = _mm_payload_bits(a_bits, b_bits, c_bits, type_a, type_b, acc_type) - aw = triton.TensorWrapper(a, dtype=torch.float32) - bw = triton.TensorWrapper(b, dtype=torch.float32) - cw = triton.TensorWrapper(c, dtype=torch.float32) - outw = triton.TensorWrapper(out, dtype=torch.float32) + _, aw = _as_float_bits_tensor(a_bits, type_a) + _, bw = _as_float_bits_tensor(b_bits, type_b) + _, cw = _as_float_bits_tensor(c_bits, acc_type) + out, outw = _as_float_bits_tensor(np.empty((m, n), dtype=_float_dtype_info(acc_type)[2]), acc_type) - kernel[(1, )](aw, bw, cw, outw, BLOCK=B, INSTR_SHAPE_K=4, K_WIDTH=2) + kernel[(1, )](aw, bw, cw, outw, BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, INSTR_SHAPE_K=instr_k, K_WIDTH=k_width) _assert_payload_equal(out, exp_bits)