Add L2 Cache optimization to PartitionK matmul kernel
Summary:
Implement the L2 cache optimization for the PartitionK matmul kernel, which previously did not have it, and additionally tune over GROUP_SIZE_M.
Benchmark results from TritonBench for several partitionK values are in the test plan. Large partitionK values lead to accuracy issues and poor performance.
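For reference, the L2 optimization here is the grouped (swizzled) program-ID ordering used in the standard Triton matmul recipe. The sketch below is illustrative only; the helper name and signature are not taken from this diff.
```
import triton
import triton.language as tl

@triton.jit
def _grouped_pid(M, N,
                 BLOCK_SIZE_M: tl.constexpr,
                 BLOCK_SIZE_N: tl.constexpr,
                 GROUP_SIZE_M: tl.constexpr):
    # Map the flat program id to (pid_m, pid_n) so that programs are launched
    # in column-major order within groups of GROUP_SIZE_M rows of blocks.
    # Consecutive programs then reuse tiles of A and B while those tiles are
    # still resident in L2.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```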
TODO: Currently unable to autotune the partitionK parameter. Asked a question in Slack along the following lines:
> For context, I am working with a version of the partition-K matmul kernel in Triton, which uses an intermediate tensor of shape (M, N, partitionK) to store the partial result of each K partition before reduction. I want to be able to autotune this partitionK value, allocating the intermediate buffer like the following:
```
import torch

def allocate_c_buf(nargs, **kwargs):
    # Re-allocate the intermediate buffer to match this config's PARTITION_SIZE_K.
    nargs["c_buf_ptr"] = torch.empty(
        (nargs["M"], nargs["N"], nargs["PARTITION_SIZE_K"]),
        device=nargs["c_buf_ptr"].device,
        dtype=nargs["c_buf_ptr"].dtype,
    )
```
> It seems the Triton autotuner does not pick up modifications made to the nargs dictionary (https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py#L151-L154); it only uses the original args, so the pre_hook effectively does nothing here. What is the recommended approach for this? Could the autotuner be modified to take the updated full_nargs and run the kernel with that?
TL;DR: As far as I can tell, Triton autotuning only supports modifying the tensor args in place; there is no way to grow a tensor's memory in place.
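For context, here is a minimal sketch of how PARTITION_SIZE_K would be exposed to the autotuner through a per-config pre_hook if the reassignment were honored. The candidate values, config contents, and tuning key are hypothetical, and this is not what this diff ships.
```
import torch
import triton

def allocate_c_buf(nargs, **kwargs):
    # Same idea as the hook quoted above: re-size the (M, N, PARTITION_SIZE_K)
    # partial-result buffer for the config being benchmarked. As noted, the
    # autotuner currently ignores this reassignment and launches with the
    # original c_buf_ptr.
    nargs["c_buf_ptr"] = torch.empty(
        (nargs["M"], nargs["N"], nargs["PARTITION_SIZE_K"]),
        device=nargs["c_buf_ptr"].device,
        dtype=nargs["c_buf_ptr"].dtype,
    )

# Hypothetical candidate values; as mentioned above, large partitionK values
# hurt both accuracy and performance in the benchmarks.
configs = [
    triton.Config({"PARTITION_SIZE_K": psk}, pre_hook=allocate_c_buf)
    for psk in (2, 4, 8)
]

# The PartitionK kernel would then be decorated with
# @triton.autotune(configs=configs, key=["M", "N", "K"]), and the host-side
# wrapper would reduce the partial results, e.g. c_buf.sum(dim=2) or an
# equivalent reduction kernel.
```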
Reviewed By: sijiac
Differential Revision: D71368870
fbshipit-source-id: 93312b1763317a670099f528aa6f369717d3104d