Commit 391846f

Nick Riasanovsky authored and facebook-github-bot committed

Enable full autotuning for AMD

Summary: Enables full autotuning for AMD so we can test across the full suite of configurations. Requires setting an environment variable.

Reviewed By: PaulZhang12

Differential Revision: D70587961

fbshipit-source-id: 8fb0faffb49bb9fed7907e0ec66e0ba39fb3ee0f

1 parent a87ce40 commit 391846f
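
To use the new path, the environment variable has to be set before the gemm operator modules are imported, since the gate is evaluated at import time. A minimal sketch of enabling it from Python (only the FULL_AUTOTUNING_AMD name and the torch.version.hip check come from the diff; the rest is illustrative):

import os

# Opt in before importing the gemm operator modules.
os.environ["FULL_AUTOTUNING_AMD"] = "1"

import torch

# Mirrors the condition this commit adds in each kernel file: the full AMD
# config space is used only on ROCm/HIP builds with the variable set to "1".
full_autotune = (
    os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None
)
print("full AMD autotuning active:", full_autotune)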

5 files changed: +109 -34 lines changed


tritonbench/operators/gemm/kernels/matmul.py

+11 -9

@@ -4,6 +4,8 @@
 
 from triton import autotune, cdiv, Config, heuristics, jit, language as tl
 
+from ..triton_matmul_configs import get_full_amd_config_space, init_to_zero
+
 from .matmul_perf_model import early_config_prune, estimate_matmul_time
 
 _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]

@@ -31,10 +33,6 @@ def get_higher_dtype(a, b):
     return a
 
 
-def init_to_zero(name):
-    return lambda nargs: nargs[name].zero_()
-
-
 def get_configs_io_bound():
     configs = []
     for num_stages in [2, 3, 4, 5, 6]:

@@ -85,9 +83,10 @@ def get_configs_io_bound():
     else {}
 )
 
-
-@autotune(
-    configs=[
+if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
+    tuning_configs = get_full_amd_config_space(True)
+else:
+    tuning_configs = [
         # basic configs for compute-bound matmuls
         Config(
             {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1, "GROUP_M": 8},

@@ -198,8 +197,11 @@ def get_configs_io_bound():
             num_stages=5,
             num_warps=2,
         ),
-    ]
-    + get_configs_io_bound(),
+    ] + get_configs_io_bound()
+
+
+@autotune(
+    configs=tuning_configs,
     key=["M", "N", "K"],
     prune_configs_by=prune_configs_by,
 )

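The same environment-variable gate is duplicated near the top of each of the four kernel files touched by this commit (only kernels/matmul.py passes use_splitk=True). A hypothetical shared helper capturing the pattern, not part of this commit, might look like:

import os
from typing import Callable, List

import torch
import triton


def select_tuning_configs(
    default_configs: List[triton.Config],
    build_amd_space: Callable[[bool], List[triton.Config]],
    use_splitk: bool = False,
) -> List[triton.Config]:
    # Hypothetical refactor of the per-file gate: fall back to the hand-picked
    # configs unless full AMD autotuning is requested on a ROCm/HIP build.
    if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
        return build_amd_space(use_splitk)
    return default_configs
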
tritonbench/operators/gemm/persistent_matmul.py

+9 -1

@@ -8,6 +8,8 @@
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_op import IS_FBCODE
 
+from .triton_matmul_configs import get_full_amd_config_space
+
 if not IS_FBCODE:
     import triton.tools.experimental_descriptor
 

@@ -96,8 +98,14 @@ def _matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
+if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
+    tuning_configs = get_full_amd_config_space(False)
+else:
+    tuning_configs = persistent_matmul_configs()
+
+
 @triton.autotune(
-    configs=persistent_matmul_configs(),
+    configs=tuning_configs,
     key=["M", "N", "K"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)

tritonbench/operators/gemm/stream_k.py

+27 -22

@@ -15,28 +15,33 @@
 
 from tritonbench.utils.env_utils import is_hip_mi300
 
-tuning_configs = [
-    triton.Config(
-        {
-            "BLOCK_M": 128,
-            "BLOCK_N": 128,
-            "BLOCK_K": 64,
-            "GROUP_M": 8,
-        },
-        num_stages=2,
-        num_warps=8,
-    ),
-    triton.Config(
-        {
-            "BLOCK_M": 64,
-            "BLOCK_N": 64,
-            "BLOCK_K": 128,
-            "GROUP_M": 8,
-        },
-        num_stages=2,
-        num_warps=8,
-    ),
-]
+from .triton_matmul_configs import get_full_amd_config_space
+
+if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
+    tuning_configs = get_full_amd_config_space(False)
+else:
+    tuning_configs = [
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 128,
+                "BLOCK_K": 64,
+                "GROUP_M": 8,
+            },
+            num_stages=2,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 64,
+                "BLOCK_N": 64,
+                "BLOCK_K": 128,
+                "GROUP_M": 8,
+            },
+            num_stages=2,
+            num_warps=8,
+        ),
+    ]
 
 
 @triton.autotune(

tritonbench/operators/gemm/triton_matmul.py

+7 -2

@@ -10,7 +10,12 @@
 import triton
 import triton.language as tl
 
-from .triton_matmul_configs import configs
+from .triton_matmul_configs import configs, get_full_amd_config_space
+
+if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
+    tuning_configs = get_full_amd_config_space(False)
+else:
+    tuning_configs = configs
 
 
 # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:

@@ -19,7 +24,7 @@
 # - An auto-tuning *key* whose change in values will trigger evaluation of all the
 #   provided configs
 @triton.autotune(
-    configs=configs,
+    configs=tuning_configs,
    key=["M", "N", "K"],
 )
 @triton.jit

tritonbench/operators/gemm/triton_matmul_configs.py

+55 -0

@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 
 import triton

@@ -259,3 +260,57 @@
         num_warps=2,
     ),
 ]
+
+
+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()
+
+
+def get_full_amd_config_space(use_splitk: bool):
+    configs = []
+
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [8]
+    waves_per_eu_range = [0, 1, 2, 4]
+
+    for block_m in block_mn_range:
+        for block_n in block_mn_range:
+            for block_k in block_k_range:
+                for num_warps in num_warps_range:
+                    for group_m in group_m_range:
+                        for waves_per_eu in waves_per_eu_range:
+                            base_config_dict = {
+                                "BLOCK_M": block_m,
+                                "BLOCK_N": block_n,
+                                "BLOCK_K": block_k,
+                                "GROUP_M": group_m,
+                                "waves_per_eu": waves_per_eu,
+                                "kpack": 2,
+                            }
+                            config_dicts = []
+                            if use_splitk:
+                                max_k_pow2 = np.int64(np.log2(block_k))
+                                split_k_range = [2**i for i in range(max_k_pow2)]
+                                for split_k in split_k_range:
+                                    config_dicts.append(
+                                        {
+                                            **base_config_dict,
+                                            "SPLIT_K": split_k,
+                                        }
+                                    )
+                            else:
+                                config_dicts.append(base_config_dict)
+                            for config_dict in config_dicts:
+                                configs.append(
+                                    triton.Config(
+                                        config_dict,
+                                        num_warps=num_warps,
+                                        num_stages=2,
+                                        pre_hook=init_to_zero("C")
+                                        if use_splitk
+                                        else None,
+                                    )
+                                )
+    return configs

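Two observations about get_full_amd_config_space may be worth keeping in mind. First, split-K configs attach init_to_zero("C") as a pre_hook because each split-K slice accumulates its partial product into C, so the output buffer has to start at zero. Second, the nested ranges make the search space large; a small sanity-check sketch of its size (counts derived from the ranges above; the import path mirrors the on-disk module path and assumes tritonbench is importable as a package):

# Illustrative check of the search-space size implied by the loops above.
from tritonbench.operators.gemm.triton_matmul_configs import get_full_amd_config_space

# 5 BLOCK_M * 5 BLOCK_N * 5 BLOCK_K * 4 num_warps * 1 GROUP_M * 4 waves_per_eu
print(len(get_full_amd_config_space(use_splitk=False)))  # 2000

# With SPLIT_K, each BLOCK_K value fans out into log2(BLOCK_K) split factors
# (1, 2, ..., BLOCK_K // 2), so 400 * (4 + 5 + 6 + 7 + 8) configs in total.
print(len(get_full_amd_config_space(use_splitk=True)))   # 12000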