Summary:
Background: [Adding SPLIT_K to Triton Templates](https://docs.google.com/document/d/1K1DwmVkzqoB_uDoWkPvmOa2xra5yko5JmQzy5Jxy-rg/edit?tab=t.0)
This diff makes two changes to the partitionK kernel:
1. Use an FP32 accumulator throughout.
2. Interleave loads from the split-K dimension for coalesced memory accesses.
**FP32 accumulator**:
Previously, we accumulated in fp32 within the kernel, but the intermediate buffer was fp16, so the reduction over the k dimension was done in fp16. This loss of precision could hurt accuracy. After discussing with sijiac and eellison, we changed the kernel to accumulate in fp32 throughout, because we want the same correctness as cuBLAS/cutlass, which use fp32 accumulation throughout. This change does hurt performance.
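To illustrate why an fp16 intermediate buffer can lose accuracy, here is a minimal sketch in plain Python. It is not the kernel code: the helper names (`to_fp16`, `to_fp32`, `reduce_partials`) and the partial values are hypothetical, and the rounding is simulated with the standard library `struct` module.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def reduce_partials(partials, cast):
    """Sum split-K partial results, rounding each stored value and each add with `cast`."""
    acc = cast(0.0)
    for p in partials:
        acc = cast(acc + cast(p))
    return acc

# Hypothetical partial sums, one per split of the K dimension.
partials = [2048.0, 1.0, 1.0, 1.0, 1.0]

fp16_result = reduce_partials(partials, to_fp16)
fp32_result = reduce_partials(partials, to_fp32)

print(fp16_result)  # 2048.0: fp16 spacing near 2048 is 2, so each +1 is rounded away
print(fp32_result)  # 2052.0: matches the exact sum
```

In this toy case the small contributions vanish entirely in fp16, while the fp32 reduction keeps them.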
**Interleave loads optimization**:
Previously, if K = 4 and PK = 2, each kernel instance processed a contiguous half of the K dimension: kernel instance 0 processed K=0,1 and kernel instance 1 processed K=2,3. Now, kernel instance 0 processes K=0,2 and kernel instance 1 processes K=1,3.
In the previous scheme, as K scales up, loads from different instances are nowhere near each other in memory, which reduces the chance of coalesced memory loads and cache hits. With the new scheme, those chances increase because the loads are closer to each other in memory.
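The two index assignments can be sketched in plain Python as follows. This is only an illustration of the mapping, not the Triton kernel: the function names are hypothetical, and the sketch assumes PK divides K evenly.

```python
def sequential_partition(k, pk):
    """Old layout: instance i handles one contiguous chunk of the K dimension."""
    chunk = k // pk  # assumption: pk divides k evenly
    return [list(range(i * chunk, (i + 1) * chunk)) for i in range(pk)]

def interleaved_partition(k, pk):
    """New layout: instance i handles indices i, i + pk, i + 2 * pk, ..."""
    return [list(range(i, k, pk)) for i in range(pk)]

print(sequential_partition(4, 2))   # [[0, 1], [2, 3]]
print(interleaved_partition(4, 2))  # [[0, 2], [1, 3]]
```

At any given step, neighboring instances in the sequential layout touch indices that are K / PK apart, whereas in the interleaved layout they touch adjacent indices.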
**Results**:
Changing accumulation to FP32 throughout causes a big performance hit, going from ~94% of aten performance to 81% on average across all shapes. However, with interleaved loads, the FP32 accumulation performance improves to 86% of aten performance. Given that traditional SPLIT_K performance in triton_ops_matmul is at ~91%, this is acceptable. DecomposeK has the highest performance, at ~96% of aten.
This kernel could potentially be improved further with TMA/persistent optimizations.
Reviewed By: sijiac
Differential Revision: D71437375
fbshipit-source-id: b79fe3022770c9b744e54bbbb73905695240c8fa