diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
index 53a79c3f5b92..08f7522d6267 100644
--- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
+++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 
 import triton
-from triton_kernels.target_info import get_cdna_version
+from triton_kernels.target_info import get_cdna_version, get_rdna_version
 from triton_kernels.tensor import FP4
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
@@ -82,6 +82,8 @@ def make_default_opt_flags_amd(
         block_m = 256 if is_cdna4 else 128
     elif is_cdna4 and m >= 512:
         block_m = 128
+    elif get_rdna_version() in (3, 4) and m >= 512:
+        block_m = 64
     else:
         block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
 
diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
index ffe06c333f60..598cbbac823b 100644
--- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
+++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
@@ -1,6 +1,6 @@
 import torch
 import triton
-from triton_kernels.target_info import get_cdna_version
+from triton_kernels.target_info import get_cdna_version, get_rdna_version
 from triton_kernels.tensor import bitwidth
 
 
@@ -23,11 +23,18 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi
 
     if get_cdna_version() == 4 and block_m == 128:
         block_n = 512
 
+    if get_rdna_version() in (3, 4) and block_m == 64:
+        block_n = 256
+
     # block_k needs to match the cacheline size (128B)
     block_k = int(128 // min(lhs_width, rhs_width))
     # TODO: block_k = 128 seems to work better for now.
     # perhaps due to increased number of k loops to pipeline
-    if precision_config.weight_scale is not None and get_cdna_version() != 4:
-        block_k = 128
+    if precision_config.weight_scale is not None:
+        if get_cdna_version() != 4:
+            block_k = 128
+
+        if get_rdna_version() in (3, 4) and block_m == 64:
+            block_k = 64
     return block_n, block_k
diff --git a/python/triton_kernels/triton_kernels/target_info.py b/python/triton_kernels/triton_kernels/target_info.py
index 4350efa8b78a..b597e64f5aba 100644
--- a/python/triton_kernels/triton_kernels/target_info.py
+++ b/python/triton_kernels/triton_kernels/target_info.py
@@ -13,6 +13,7 @@
 __all__ = [
     "cuda_capability_geq",
     "get_cdna_version",
+    "get_rdna_version",
     "has_tma_gather",
     "has_native_mxfp",
     "is_cuda",
@@ -40,6 +41,23 @@ def get_cdna_version():
     return -1
 
 
+@triton.constexpr_function
+def get_rdna_version():
+    """
+    Gets the AMD RDNA architecture version, i.e. RDNA3 or RDNA4, by matching
+    gfx11* (RDNA3) or gfx12* (RDNA4). Returns -1 on non-AMD hardware or an
+    unsupported architecture.
+    """
+    target = tl.target_info.current_target()
+    if target.backend != 'hip':
+        return -1
+    if target.arch.startswith('gfx11'):
+        return 3
+    if target.arch.startswith('gfx12') and not target.arch.startswith('gfx125'):
+        return 4
+    return -1
+
+
 @triton.constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)
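
For review, a minimal standalone sketch of the arch-matching heuristic this patch adds to target_info.py, restated in plain Python so it can be sanity-checked without a Triton runtime. The helper name rdna_version_from_target() and its (backend, arch) signature are hypothetical; the real get_rdna_version() takes no arguments and reads tl.target_info.current_target() instead:

    # Hypothetical mirror of get_rdna_version() for illustration only.
    def rdna_version_from_target(backend: str, arch: str) -> int:
        """Map a HIP gfx arch string to an RDNA generation, or -1."""
        if backend != 'hip':
            return -1
        if arch.startswith('gfx11'):  # RDNA3, e.g. gfx1100/gfx1101
            return 3
        # gfx125* is deliberately excluded by the patch's check below.
        if arch.startswith('gfx12') and not arch.startswith('gfx125'):
            return 4                  # RDNA4, e.g. gfx1200/gfx1201
        return -1

    assert rdna_version_from_target('hip', 'gfx1100') == 3
    assert rdna_version_from_target('hip', 'gfx1201') == 4
    assert rdna_version_from_target('hip', 'gfx1250') == -1  # excluded
    assert rdna_version_from_target('cuda', 'sm_90') == -1   # non-HIP

On targets where this returns 3 or 4, the patch pins block_m to 64 for m >= 512 in make_default_opt_flags_amd, which in turn selects block_n = 256 and, when a weight scale is configured, block_k = 64 in compute_block_nk.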