
Commit 62d4b4f

PaulZhang12 authored and facebook-github-bot committed
Add L2 Cache optimization to PartitionK matmul kernel
Summary: Implement the L2 cache optimization for the PartitionK matmul kernel, which was previously missing, and tune over GROUP_SIZE_M. Benchmark results from TritonBench for several partitionK values are in the test plan; large partitionK values lead to accuracy issues and poor performance.

TODO: Currently unable to tune on the partitionK parameter. Asked a question in Slack along the lines of the following:

> For context, I am working with a version of the partition-K matmul kernel in Triton, which contains an intermediate tensor of shape (M, N, partitionK) for storing the partial result of each K partition before reduction. I want to be able to autotune on this value of partitionK, allocating the intermediate buffer like the following:

```
def allocate_c_buf(nargs, **kwargs):
    nargs["c_buf_ptr"] = torch.empty(
        (nargs["M"], nargs["N"], nargs["PARTITION_SIZE_K"]),
        device=nargs["c_buf_ptr"].device,
        dtype=nargs["c_buf_ptr"].dtype,
    )
```

> It seems like the Triton autotuner does not use the modified nargs dictionary (https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py#L151-L154), just the original args, so the pre_hook doesn't actually do anything here. What is the recommended approach? Can the autotuner code be modified to take the modified version of full_nargs and run the kernel with that?

TL;DR: Triton autotuning appears to only support modifying tensor args in place; there is no way to grow a tensor's memory in place.

Reviewed By: sijiac

Differential Revision: D71368870

fbshipit-source-id: 93312b1763317a670099f528aa6f369717d3104d
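For illustration, a minimal sketch of the pattern the summary describes: attaching the quoted pre_hook to candidate configs so that partitionK could, in principle, vary per config. This is not code from this commit; it assumes a Triton version where triton.Config accepts a pre_hook, the block sizes are illustrative, and PARTITION_SIZE_K mirrors the quoted snippet. Per the TL;DR above, the reallocation inside the hook never reaches the actual kernel launch.

```python
# Sketch only (not part of this commit).
import torch
import triton


def allocate_c_buf(nargs, **kwargs):
    # Re-allocate the intermediate (M, N, partitionK) buffer for the candidate
    # PARTITION_SIZE_K. As noted above, the autotuner keeps passing the original
    # args, so this mutation of nargs never reaches the kernel launch.
    nargs["c_buf_ptr"] = torch.empty(
        (nargs["M"], nargs["N"], nargs["PARTITION_SIZE_K"]),
        device=nargs["c_buf_ptr"].device,
        dtype=nargs["c_buf_ptr"].dtype,
    )


partition_k_candidates = [
    triton.Config(
        {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
         "PARTITION_SIZE_K": psk},
        num_stages=4,
        num_warps=4,
        pre_hook=allocate_c_buf,  # hypothetical: one buffer size per config
    )
    for psk in (16, 32, 64)
]
```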
1 parent e5be656 commit 62d4b4f

File tree

1 file changed: +45 -22 lines changed

tritonbench/operators/gemm/partition_k.py (+45 -22)

@@ -4,8 +4,8 @@
 import triton.language as tl


-@triton.autotune(
-    configs=[
+def get_mm_configs():
+    configs = [
         triton.Config(
             {
                 "BLOCK_SIZE_M": 32,
@@ -60,7 +60,27 @@
             num_stages=6,
             num_warps=2,
         ),
-    ],
+    ]
+
+    partition_k_configs = []
+    for config in configs:
+        for GROUP_SIZE_M in [1, 4, 8]:
+            partition_k_configs.append(
+                triton.Config(
+                    {
+                        **config.kwargs,
+                        "GROUP_SIZE_M": GROUP_SIZE_M,
+                    },
+                    num_stages=config.num_stages,
+                    num_warps=config.num_warps,
+                )
+            )
+
+    return partition_k_configs
+
+
+@triton.autotune(
+    configs=get_mm_configs(),
     key=["M", "N", "K", "PK"],
 )
 @triton.jit
@@ -89,6 +109,7 @@ def _matmul_partition_k(
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,  #
+    GROUP_SIZE_M: tl.constexpr,
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -97,21 +118,22 @@ def _matmul_partition_k(
     # Map program ids `pid` to the block of C it should compute.
     # This is done in a grouped ordering to promote L2 data reuse.
     # See above `L2 Cache Optimizations` section for details.
-    pid_m = tl.program_id(axis=0)
-    pid_n = tl.program_id(axis=1)
-    pid_pk = tl.program_id(axis=2)
-    # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    # num_pid_pk = PK
-    # num_pid_nk = num_pid_n * num_pid_pk
-    # num_pid_in_group = GROUP_SIZE_M * num_pid_nk
-    # group_id = pid // num_pid_in_group
-    # first_pid_m = group_id * GROUP_SIZE_M
-    # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    # pid_m = first_pid_m + (pid % group_size_m)
-    # pid_nk = (pid % num_pid_in_group) // group_size_m
-    # pid_n = pid_nk // num_pid_n
-    # pid_pk = pid_nk % num_pid_n
+    # pid_m = tl.program_id(axis=0)
+    # pid_n = tl.program_id(axis=1)
+    # pid_pk = tl.program_id(axis=2)
+    pid = tl.program_id(0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_pk = PK
+    num_pid_nk = num_pid_n * num_pid_pk
+    num_pid_in_group = GROUP_SIZE_M * num_pid_nk
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_nk = (pid % num_pid_in_group) // group_size_m
+    pid_n = pid_nk // num_pid_pk
+    pid_pk = pid_nk % num_pid_pk

     # ----------------------------------------------------------
     # Create pointers for the first blocks of A and B.
@@ -198,7 +220,8 @@ def matmul_partition_k(a, b, triton_reduce=False):
     assert a.is_contiguous(), "Matrix A must be contiguous"
     assert b.is_contiguous(), "Matrix B must be contiguous"

-    partitionK = 64
+    # TODO: Tune on this parameter, currently 32 is best performing
+    partitionK = 32

     M, K = a.shape
     K, N = b.shape
@@ -210,9 +233,9 @@ def matmul_partition_k(a, b, triton_reduce=False):
     # 1D launch kernel where each block gets its own program.

     grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]),
-        triton.cdiv(N, META["BLOCK_SIZE_N"]),
-        partitionK,
+        triton.cdiv(M, META["BLOCK_SIZE_M"])
+        * triton.cdiv(N, META["BLOCK_SIZE_N"])
+        * partitionK,
     )
     _matmul_partition_k[grid](
         a,
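The core of the change is the grouped pid decomposition in the kernel. A small host-side sketch (plain Python, mirroring the arithmetic added above; the toy sizes are made up) can be used to print which (pid_m, pid_n, pid_pk) tile each flat program id maps to:

```python
# Host-side sketch of the grouped ordering added in this diff (toy sizes).
import math


def tile_for_pid(pid, M, N, PK, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    # Same index math as the kernel, on the host.
    num_pid_m = math.ceil(M / BLOCK_SIZE_M)
    num_pid_n = math.ceil(N / BLOCK_SIZE_N)
    num_pid_pk = PK
    num_pid_nk = num_pid_n * num_pid_pk
    num_pid_in_group = GROUP_SIZE_M * num_pid_nk
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_nk = (pid % num_pid_in_group) // group_size_m
    pid_n = pid_nk // num_pid_pk
    pid_pk = pid_nk % num_pid_pk
    return pid_m, pid_n, pid_pk


for pid in range(16):
    print(pid, tile_for_pid(pid, M=128, N=128, PK=2,
                            BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, GROUP_SIZE_M=4))
```

With GROUP_SIZE_M = 4, consecutive program ids keep the same (pid_n, pid_pk) and walk down a group of row blocks before moving to the next column/partition, which is the access pattern the "grouped ordering to promote L2 data reuse" comment refers to.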

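For completeness, a minimal usage sketch of the changed entry point. The import path, fp16 CUDA inputs, and the loose tolerances are assumptions (partition-K changes the reduction order, and the summary notes accuracy issues for large partitionK).

```python
# Usage sketch; import path and tolerances are assumptions.
import torch

from tritonbench.operators.gemm.partition_k import matmul_partition_k

a = torch.randn(512, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 512, device="cuda", dtype=torch.float16)

c = matmul_partition_k(a, b)  # autotunes over get_mm_configs() on first call
torch.testing.assert_close(c, (a @ b).to(c.dtype), rtol=2e-2, atol=2e-2)
```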