
Commit b69a816

PaulZhang12 authored and facebook-github-bot committed

PartitionK fp32 accumulator + interleave loads
Summary:

Background: [Adding SPLIT_K to Triton Templates](https://docs.google.com/document/d/1K1DwmVkzqoB_uDoWkPvmOa2xra5yko5JmQzy5Jxy-rg/edit?tab=t.0)

This diff makes two changes to the partitionK kernel:

1. Use an FP32 accumulator throughout.
2. Interleave loads along the split-K dimension for coalesced memory accesses.

**FP32 accumulator**: Previously, each kernel instance accumulated in fp32, but the intermediate buffer was fp16, so the reduction over the K partitions happened in fp16. This loss of precision could hurt accuracy. After discussing with sijiac and eellison, and since we want the same correctness as cuBLAS/CUTLASS, which accumulate in fp32 throughout, we made this change in the kernel even though it hurts performance.

**Interleave loads optimization**: Previously, with K = 4 and PK = 2, each kernel instance processed half of the K dimension sequentially: instance 0 handled K = 0, 1 and instance 1 handled K = 2, 3. Now instance 0 handles K = 0, 2 and instance 1 handles K = 1, 3. In the old scheme, as K scales up, concurrent instances load addresses that are far apart, reducing the chance of coalesced memory loads and cache hits. In the new scheme, those chances increase because concurrent loads are closer together in memory; a sketch of the two schemes is shown below.

**Results**: Switching to FP32 accumulation throughout takes a big performance hit, dropping from ~94% of aten performance to ~81% on average across all shapes. With interleaved loads, FP32-accumulation performance recovers to ~86% of aten. Given that traditional SPLIT_K performance in triton_ops_matmul is at ~91%, this is acceptable. DecomposeK has the highest performance at ~96% of aten. This kernel could potentially be improved further with TMA/persistent optimizations.

Reviewed By: sijiac

Differential Revision: D71437375

fbshipit-source-id: b79fe3022770c9b744e54bbbb73905695240c8fa
1 parent 6a8e573 commit b69a816
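
As a rough illustration of the interleaving, here is a minimal standalone Python sketch (not the kernel code; `num_k_blocks` and `num_partitions` are hypothetical stand-ins for K // BLOCK_SIZE_K and partitionK) of which K blocks each split-K program instance touches under the two schemes:

```python
# Minimal sketch: which K-block indices each split-K program instance loads.
# `num_k_blocks` and `num_partitions` are illustrative stand-ins for
# K // BLOCK_SIZE_K and partitionK; this is not the kernel code itself.

def sequential_blocks(pid_pk: int, num_k_blocks: int, num_partitions: int) -> list[int]:
    # Old scheme: each instance owns a contiguous chunk of the K dimension,
    # so concurrent instances read addresses that are far apart.
    chunk = num_k_blocks // num_partitions
    return list(range(pid_pk * chunk, (pid_pk + 1) * chunk))


def interleaved_blocks(pid_pk: int, num_k_blocks: int, num_partitions: int) -> list[int]:
    # New scheme: each instance strides through K, so concurrent instances
    # read neighboring blocks and are more likely to share cache lines.
    return list(range(pid_pk, num_k_blocks, num_partitions))


if __name__ == "__main__":
    # Matches the K = 4, PK = 2 example from the summary above.
    for pid in range(2):
        print("sequential ", pid, sequential_blocks(pid, 4, 2))   # [0, 1] / [2, 3]
        print("interleaved", pid, interleaved_blocks(pid, 4, 2))  # [0, 2] / [1, 3]
```

In the kernel this corresponds to computing `offs_k` from `pid_pk * BLOCK_SIZE_K` and advancing the pointers by `PK_SIZE` each iteration, as in the diff below.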

File tree

1 file changed (+7, -7 lines)

tritonbench/operators/gemm/partition_k.py

@@ -144,7 +144,7 @@ def _matmul_partition_k(
     # See above `Pointer Arithmetic` section for details
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = (pid_pk * PK_SIZE + tl.arange(0, BLOCK_SIZE_K)) % K
+    offs_k = (pid_pk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

@@ -162,9 +162,8 @@ def _matmul_partition_k(
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         accumulator += tl.dot(a, b)
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
-    acc = accumulator.to(tl.float16)
+        a_ptrs += PK_SIZE * stride_ak
+        b_ptrs += PK_SIZE * stride_bk

     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -175,7 +174,7 @@ def _matmul_partition_k(
         + stride_cb_n * offs_cn[None, :, None]
         + stride_cb_k * offs_ck[None, None, :]
     )
-    tl.store(c_buf_ptrs, acc[:, :, None])
+    tl.store(c_buf_ptrs, accumulator[:, :, None])


 @triton.jit
@@ -228,7 +227,8 @@ def matmul_partition_k(a, b, triton_reduce=False):
     # Allocates output.
     partitionK_SIZE = K // partitionK

-    c_buf = torch.empty((M, N, partitionK), device=a.device, dtype=a.dtype)
+    # Enforce accumulation in float32 for accuracy
+    c_buf = torch.empty((M, N, partitionK), device=a.device, dtype=torch.float32)
     c = torch.empty((M, N), device=a.device, dtype=a.dtype)
     # 1D launch kernel where each block gets its own program.

@@ -276,4 +276,4 @@ def matmul_partition_k(a, b, triton_reduce=False):
         )
         return c
     else:
-        return c_buf.sum(dim=2)
+        return c_buf.sum(dim=2).to(a.dtype)
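
To see why the fp32 intermediate buffer matters for the final split-K reduction, here is a minimal sketch that emulates the reduction path with an fp16 versus fp32 `c_buf`; `emulate_partition_k` is a helper invented for illustration, not code from this file:

```python
import torch

# Minimal sketch (hypothetical helper, not the kernel): emulate the
# partition-K reduction path with an fp16 vs. fp32 intermediate buffer.
def emulate_partition_k(a, b, partition_k, buf_dtype):
    M, K = a.shape
    _, N = b.shape
    pk_size = K // partition_k
    # One partial product per K partition, stored in `buf_dtype`,
    # mirroring c_buf of shape (M, N, partitionK).
    c_buf = torch.empty((M, N, partition_k), device=a.device, dtype=buf_dtype)
    for pid in range(partition_k):
        ks = slice(pid * pk_size, (pid + 1) * pk_size)
        partial = a[:, ks].float() @ b[ks, :].float()  # fp32 accumulation in-kernel
        c_buf[:, :, pid] = partial.to(buf_dtype)
    # Final reduction over the split-K dimension, then cast to the input dtype.
    return c_buf.sum(dim=2).to(a.dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(256, 4096, dtype=torch.float16)
    b = torch.randn(4096, 256, dtype=torch.float16)
    ref = (a.float() @ b.float()).to(torch.float16)
    err_fp16 = (emulate_partition_k(a, b, 8, torch.float16) - ref).abs().max()
    err_fp32 = (emulate_partition_k(a, b, 8, torch.float32) - ref).abs().max()
    print(f"max error with fp16 buffer: {err_fp16.item():.4f}")
    print(f"max error with fp32 buffer: {err_fp32.item():.4f}")
```

Storing and reducing the partials in fp32 keeps the only fp16 rounding at the final cast, which is the same behavior the kernel now enforces via the float32 `c_buf`.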
