@@ -4,6 +4,8 @@
 
 from triton import autotune, cdiv, Config, heuristics, jit, language as tl
 
+from ..triton_matmul_configs import get_full_amd_config_space, init_to_zero
+
 from .matmul_perf_model import early_config_prune, estimate_matmul_time
 
 _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]
@@ -31,10 +33,6 @@ def get_higher_dtype(a, b):
     return a
 
 
-def init_to_zero(name):
-    return lambda nargs: nargs[name].zero_()
-
-
 def get_configs_io_bound():
     configs = []
     for num_stages in [2, 3, 4, 5, 6]:
@@ -85,9 +83,10 @@ def get_configs_io_bound():
         else {}
     )
 
-
-@autotune(
-    configs=[
+if os.environ.get("FULL_AUTOTUNING_AMD", "0") == "1" and torch.version.hip is not None:
+    tuning_configs = get_full_amd_config_space(True)
+else:
+    tuning_configs = [
         # basic configs for compute-bound matmuls
         Config(
             {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1, "GROUP_M": 8},
@@ -198,8 +197,11 @@ def get_configs_io_bound():
             num_stages=5,
             num_warps=2,
         ),
-    ]
-    + get_configs_io_bound(),
+    ] + get_configs_io_bound()
+
+
+@autotune(
+    configs=tuning_configs,
     key=["M", "N", "K"],
     prune_configs_by=prune_configs_by,
 )
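For context, both the deleted helper and the new import point at a shared triton_matmul_configs module that is not part of this diff. Below is a minimal sketch of what that module could contain: init_to_zero is copied verbatim from the removed lines, while the body of get_full_amd_config_space (including its parameter name, here called use_split_k, and the output-argument name "C") is an assumption; the diff only guarantees that the function takes one boolean argument and returns configs usable with @autotune.

    # triton_matmul_configs.py -- hypothetical sketch, not the actual module.
    import itertools

    from triton import Config


    def init_to_zero(name):
        # Pre-hook for SPLIT_K > 1 configs: zero the named output buffer before
        # the kernel accumulates partial results into it (body taken from the diff).
        return lambda nargs: nargs[name].zero_()


    def get_full_amd_config_space(use_split_k):
        # Assumed implementation: sweep a broad grid of tile shapes for ROCm.
        # The grid, warp/stage counts, and the "C" argument name are guesses.
        configs = []
        block_sizes = [32, 64, 128, 256]
        split_k_values = [1, 2, 4, 8] if use_split_k else [1]
        for block_m, block_n, block_k in itertools.product(block_sizes, repeat=3):
            for split_k in split_k_values:
                pre_hook = init_to_zero("C") if split_k > 1 else None
                configs.append(
                    Config(
                        {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k,
                         "SPLIT_K": split_k, "GROUP_M": 8},
                        num_stages=2,
                        num_warps=4,
                        pre_hook=pre_hook,
                    )
                )
        return configs

With this change, the expanded search space is only used on a ROCm build of PyTorch (torch.version.hip is not None) and only when the process is launched with FULL_AUTOTUNING_AMD=1; in every other case the hand-written config list plus get_configs_io_bound() is used, exactly as before.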