Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import triton
from triton_kernels.target_info import get_cdna_version
from triton_kernels.target_info import get_cdna_version, get_rdna_version
from triton_kernels.tensor import FP4
import torch
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
Expand Down Expand Up @@ -82,6 +82,8 @@ def make_default_opt_flags_amd(
block_m = 256 if is_cdna4 else 128
elif is_cdna4 and m >= 512:
block_m = 128
elif get_rdna_version() in (3, 4) and m >= 512:
block_m = 64
else:
block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import triton
from triton_kernels.target_info import get_cdna_version
from triton_kernels.target_info import get_cdna_version, get_rdna_version
from triton_kernels.tensor import bitwidth


Expand All @@ -23,11 +23,18 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi
if get_cdna_version() == 4 and block_m == 128:
block_n = 512

if get_rdna_version() in (3, 4) and block_m == 64:
block_n = 256

# block_k needs to match the cacheline size (128B)
block_k = int(128 // min(lhs_width, rhs_width))

# TODO: block_k = 128 seems to work better for now.
# perhaps due to increased number of k loops to pipeline
if precision_config.weight_scale is not None and get_cdna_version() != 4:
block_k = 128
if precision_config.weight_scale is not None:
if get_cdna_version() != 4:
block_k = 128

if get_rdna_version() in (3, 4) and block_m == 64:
block_k = 64
return block_n, block_k
18 changes: 18 additions & 0 deletions python/triton_kernels/triton_kernels/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
__all__ = [
"cuda_capability_geq",
"get_cdna_version",
"get_rdna_version",
"has_tma_gather",
"has_native_mxfp",
"is_cuda",
Expand Down Expand Up @@ -40,6 +41,23 @@ def get_cdna_version():
return -1


@triton.constexpr_function
def get_rdna_version():
    """
    Report the RDNA generation of the current compilation target.

    Returns 3 when the target arch starts with gfx11 (RDNA3), 4 when it
    starts with gfx12 but not gfx125 (RDNA4), and -1 when the backend is
    not 'hip' or the arch matches neither prefix.
    """
    current = tl.target_info.current_target()
    # Only the HIP backend is considered; everything else is reported as -1.
    if current.backend != 'hip':
        return -1

    arch_name = current.arch
    if arch_name.startswith('gfx11'):
        return 3

    # gfx125* shares the gfx12 prefix but is explicitly not reported as 4.
    matches_gfx12 = arch_name.startswith('gfx12')
    if matches_gfx12 and not arch_name.startswith('gfx125'):
        return 4
    return -1


@triton.constexpr_function
def has_tma_gather():
    """
    Report whether TMA gather is available.

    This is exactly the result of cuda_capability_geq(10, 0).
    """
    required_major, required_minor = 10, 0
    return cuda_capability_geq(required_major, required_minor)
Expand Down
Loading