
Commit fe0401a

address review comments

1 parent 38d0107 commit fe0401a

3 files changed: 43 additions & 5 deletions

tritonbench/kernels/triton_fused_attention.py

Lines changed: 17 additions & 1 deletion
@@ -17,6 +17,8 @@
 import torch
 import triton
 import triton.language as tl
+from triton import knobs
+
 
 from .attention_utils import (
     HAS_EXPLICIT_WS,  # guard new tuning configs such as num_consumer_groups
@@ -28,6 +30,20 @@
 )
 
 
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def is_hip_async_copy_enabled():
+    if is_cuda():
+        return False
+
+    # default is enabled
+    if knobs.amd.use_async_copy is None:
+        return True
+    return knobs.amd.use_async_copy
+
+
 if HAS_TMA_DESC:
     print(
         "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
@@ -481,7 +497,7 @@ def get_fwd_config_space(
     bmList = [128] if enable_ws else [64, 128]
     bnList = [64, 128]  # To handle hDim of 64, we need BLOCK_N to be <= 64
     wList = [4] if enable_ws else [4, 8]
-    stageList = [2] if enable_ws else [3, 4, 7] if torch.version.hip is None else [3]
+    stageList = [2] if enable_ws else [3] if is_hip_async_copy_enabled() else [3, 4, 7]
     for BM in bmList:
         for BN in bnList:
             for sched in schedList:  # set in global scope
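
Note on the helper's semantics: on CUDA it always returns False; on HIP an unset knob (None) is treated as enabled, matching the backend default. A minimal sketch of how the rewritten stageList selection resolves; the standalone helper copy mirrors the diff above, while pick_stage_list is a hypothetical wrapper added here for illustration only:

```python
import triton
from triton import knobs


def is_hip_async_copy_enabled():
    # Same logic as the helper added in this commit: CUDA never takes
    # the HIP async-copy path, and an unset knob (None) means "enabled".
    if triton.runtime.driver.active.get_current_target().backend == "cuda":
        return False
    if knobs.amd.use_async_copy is None:
        return True
    return knobs.amd.use_async_copy


def pick_stage_list(enable_ws: bool) -> list:
    # Hypothetical wrapper spelling out the stageList expression from
    # get_fwd_config_space for readability.
    if enable_ws:
        return [2]
    # HIP with async copy keeps the small stage space; CUDA, and HIP
    # with the knob explicitly set to False, search the larger one.
    return [3] if is_hip_async_copy_enabled() else [3, 4, 7]
```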

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 14 additions & 3 deletions
@@ -3,6 +3,7 @@
 
 import torch
 import triton
+from triton import knobs
 import triton.language as tl
 from torch._inductor.kernel.mm import ScalingType
 from tritonbench.utils.env_utils import is_cuda
@@ -24,6 +25,16 @@
     pass
 
 
+def is_hip_async_copy_enabled():
+    if is_cuda():
+        return False
+
+    # default is enabled
+    if knobs.amd.use_async_copy is None:
+        return True
+    return knobs.amd.use_async_copy
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
@@ -135,23 +146,23 @@ def matmul_persistent(a, b):
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 8,
-            "num_stages": 4 if torch.version.hip is None else 3,
+            "num_stages": 3 if is_hip_async_copy_enabled() else 4,
             "num_warps": 8,
         },
         torch.float16: {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3 if torch.version.hip is None else 2,
+            "num_stages": 2 if is_hip_async_copy_enabled() else 3,
             "num_warps": 8,
         },
         torch.bfloat16: {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3 if torch.version.hip is None else 2,
+            "num_stages": 2 if is_hip_async_copy_enabled() else 3,
             "num_warps": 8,
         },
     }
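
If the pre-change HIP stage counts are needed for comparison, the knob can be set before the configs are built. A hedged usage sketch, assuming knobs.amd.use_async_copy is writable at runtime like other triton.knobs settings:

```python
from triton import knobs

# Explicitly disable async copy on HIP: matmul_persistent then falls
# back to the deeper CUDA-style num_stages values (4 for fp8, 3 for
# fp16/bf16) instead of 3 and 2.
knobs.amd.use_async_copy = False

# Reset to the default resolution (None is treated as "enabled" on HIP).
knobs.amd.use_async_copy = None
```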

tritonbench/operators/grouped_gemm/kernels.py

Lines changed: 12 additions & 1 deletion
@@ -32,6 +32,7 @@
 import torch
 import triton
 import triton.language as tl
+from triton import knobs
 
 try:
     # @manual=//triton:triton
@@ -50,6 +51,16 @@ def num_sms():
     return torch.cuda.get_device_properties("cuda").multi_processor_count
 
 
+def is_hip_async_copy_enabled():
+    if is_cuda():
+        return False
+
+    # default is enabled
+    if knobs.amd.use_async_copy is None:
+        return True
+    return knobs.amd.use_async_copy
+
+
 def torch_dtype_to_triton_dtype(dtype):
     if dtype == torch.float16:
         return tl.float16
@@ -72,7 +83,7 @@ def torch_dtype_to_triton_dtype(dtype):
                 "BLOCK_SIZE_K": BLOCK_K,
                 "NUM_SMS": num_sms(),
             },
-            num_stages=3 if torch.version.hip is None else 2,
+            num_stages=2 if is_hip_async_copy_enabled() else 3,
         )
         for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product([128, 256], repeat=3)
     ],
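
As a sanity check on the rewritten conditionals: the new expressions agree with the old torch.version.hip checks in every case except a HIP build with the knob explicitly set to False. A self-contained pure-Python sketch (all names hypothetical):

```python
def old_num_stages(is_hip):
    # Pre-change selection: keyed purely on the backend.
    return 2 if is_hip else 3


def new_num_stages(is_hip, knob):
    # Post-change selection: None (the default) counts as enabled on HIP.
    async_copy = is_hip and (knob is None or knob)
    return 2 if async_copy else 3


for is_hip in (False, True):
    for knob in (None, True, False):
        old, new = old_num_stages(is_hip), new_num_stages(is_hip, knob)
        marker = "" if old == new else "  <- only divergence"
        print(f"hip={is_hip} knob={knob}: old={old} new={new}{marker}")
```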
