
Commit 7e3982d

PaulZhang12 authored and facebook-github-bot committed

torch.compile reduction + cast

Summary: torch.compile the reduction + cast to fuse the two kernels. PartitionK now performs better than cuBLAS.

Reviewed By: sijiac

Differential Revision: D71483304

fbshipit-source-id: 060d4a8c1be2fe3487876c6b7993e50e7d90fc1a
1 parent b69a816 commit 7e3982d
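For context, a partition-K (split-K) GEMM splits the K dimension across kernel instances, each writing a float32 partial product into a buffer that is reduced afterwards; that reduction + cast is the step this commit compiles. A hypothetical pure-PyTorch sketch of the idea (the function name and `pk` are illustrative, not the Triton kernel itself):

```python
import torch

def matmul_partition_k_reference(a, b, pk=4):
    # Sketch of split-K matmul: partition K into `pk` chunks, compute one
    # partial GEMM per chunk, then reduce the partials and cast back.
    M, K = a.shape
    K2, N = b.shape
    assert K == K2 and K % pk == 0, "K must be divisible by pk"
    chunk = K // pk
    # Accumulate each K-partition's partial product in float32.
    c_buf = torch.stack(
        [
            a[:, i * chunk : (i + 1) * chunk].float()
            @ b[i * chunk : (i + 1) * chunk, :].float()
            for i in range(pk)
        ],
        dim=2,
    )  # shape (M, N, pk)
    # The reduction + cast this commit moves under torch.compile.
    return c_buf.sum(dim=2).to(a.dtype)
```

In the real kernel each partition is a separate Triton program instance writing into `c_buf`; the loop above only mimics the resulting buffer layout.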

File tree

1 file changed: +8 −1 lines changed


tritonbench/operators/gemm/partition_k.py (+8 −1)

```diff
@@ -213,6 +213,13 @@ def _reduce(
     tl.store(c_ptrs, reduced_k)
 
 
+def torch_reduction(c_buf, a):
+    return c_buf.sum(dim=2).to(a.dtype)
+
+
+compiled_reduction = torch.compile(torch_reduction)
+
+
 def matmul_partition_k(a, b, triton_reduce=False):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
@@ -276,4 +283,4 @@ def matmul_partition_k(a, b, triton_reduce=False):
         )
         return c
     else:
-        return c_buf.sum(dim=2).to(a.dtype)
+        return compiled_reduction(c_buf, a)
```
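The change can be sketched in isolation: wrapping the sum-then-cast function in torch.compile lets the compiler fuse the reduction and the dtype cast into a single kernel instead of launching one kernel per op (a minimal sketch, assuming `c_buf` is the float32 partial-product buffer of shape `(M, N, PARTITION_K)` produced by the split-K kernel):

```python
import torch

def torch_reduction(c_buf, a):
    # Sum the partial products over the partition-K axis, then cast the
    # float32 accumulator back to the input dtype (e.g. fp16/bf16).
    return c_buf.sum(dim=2).to(a.dtype)

# torch.compile traces the function so the reduction and the cast can be
# fused into one generated kernel; compilation happens lazily on first call.
compiled_reduction = torch.compile(torch_reduction)
```

Defining `compiled_reduction` at module scope, as the diff does, means the compilation cost is paid once on the first invocation and cached for subsequent calls.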
