-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmatmul.py
More file actions
139 lines (112 loc) · 5.5 KB
/
matmul.py
File metadata and controls
139 lines (112 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import triton
import triton.language as tl
# Device on which the test and benchmark tensors are allocated.
DEVICE = "cuda"

# Candidate tilings tried by the autotuner, one row per configuration:
# (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps).
# Every candidate uses GROUP_SIZE = 8.
_TILE_CHOICES = (
    (128, 256, 64, 3, 8),
    (64, 256, 32, 4, 4),
    (128, 128, 32, 4, 4),
    (128, 64, 32, 4, 4),
    (64, 128, 32, 4, 4),
    (128, 32, 32, 4, 4),
    (64, 32, 32, 5, 2),
    (32, 64, 32, 5, 2),
)

autotune_configs = [
    triton.Config(
        {
            'BLOCK_SIZE_M': block_m,
            'BLOCK_SIZE_N': block_n,
            'BLOCK_SIZE_K': block_k,
            'GROUP_SIZE': 8,
        },
        num_stages=stages,
        num_warps=warps,
    )
    for block_m, block_n, block_k, stages, warps in _TILE_CHOICES
]
@triton.autotune(configs=autotune_configs, key=['M', 'N', 'K'])
@triton.jit
def _matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_a_M, stride_a_K,
    stride_b_K, stride_b_N,
    stride_c_M, stride_c_N,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of C = A @ B.

    A is (M, K), B is (K, N) and C is (M, N); each is addressed through its
    pointer and the two strides passed for it. The launch grid is 1-D with
    cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs. Products are
    accumulated in float32 and cast to the element type of ``c_ptr`` on store.
    """
    # Map the linear program id onto a (PID_M, PID_N) tile coordinate.
    # Programs are numbered in groups of GROUP_SIZE row-tiles: inside a group
    # consecutive ids walk down M first, then advance along N (the "grouped
    # ordering" of the Triton matmul tutorial, presumably for better cache
    # reuse -- not measured here).
    PID = tl.program_id(axis=0)
    num_PID_along_M = tl.cdiv(M, BLOCK_SIZE_M)
    num_PID_along_N = tl.cdiv(N, BLOCK_SIZE_N)
    num_PID_in_group = GROUP_SIZE * num_PID_along_N
    group_id = PID // num_PID_in_group
    first_PID_in_group_along_M = group_id * GROUP_SIZE
    # The last group is shorter when num_PID_along_M is not a multiple of GROUP_SIZE.
    group_size_adj = min(num_PID_along_M - first_PID_in_group_along_M, GROUP_SIZE)
    PID_M = first_PID_in_group_along_M + ((PID % num_PID_in_group) % group_size_adj)
    PID_N = (PID % num_PID_in_group) // group_size_adj

    # Row / column indices covered by this tile, and the K indices of one step.
    offsets_M = PID_M * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offsets_N = PID_N * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offsets_K = tl.arange(0, BLOCK_SIZE_K)
    a_offsets = offsets_M[:, None] * stride_a_M + offsets_K[None, :] * stride_a_K
    b_offsets = offsets_K[:, None] * stride_b_K + offsets_N[None, :] * stride_b_N

    # Edge tiles extend past M or N when the sizes are not multiples of the
    # block sizes. These bounds guard both the loads and the final store, so
    # the kernel never touches memory outside A, B or C.
    rows_in_bounds = offsets_M[:, None] < M
    cols_in_bounds = offsets_N[None, :] < N

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Only the last K step can be partial; out-of-range elements are
        # loaded as 0.0 so they do not contribute to the dot product.
        k_in_bounds = offsets_K < K - k * BLOCK_SIZE_K
        a = tl.load(a_ptr + a_offsets, mask=rows_in_bounds & k_in_bounds[None, :], other=0.0)
        b = tl.load(b_ptr + b_offsets, mask=k_in_bounds[:, None] & cols_in_bounds, other=0.0)
        accumulator = tl.dot(a, b, acc=accumulator)
        # Slide both tiles one block further along K.
        a_offsets += BLOCK_SIZE_K * stride_a_K
        b_offsets += BLOCK_SIZE_K * stride_b_K

    # Cast once, straight to the output element type (float16 when launched
    # through `matmul`).
    c = accumulator.to(c_ptr.dtype.element_ty)
    c_offsets = stride_c_M * offsets_M[:, None] + stride_c_N * offsets_N[None, :]
    c_mask = rows_in_bounds & cols_in_bounds
    tl.store(c_ptr + c_offsets, c, mask=c_mask)
def matmul(a, b):
    """Multiply two matrices with the Triton kernel ``_matmul_kernel``.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N), on the same device as ``a``.

    Returns:
        A new float16 tensor of shape (M, N) on ``a.device``. Both inputs are
        converted to float16 before the multiplication.

    Raises:
        AssertionError: if either input is not 2-D, the inner dimensions do
            not match, or the inputs live on different devices.
    """
    # Validate the inputs:
    # 1. both must be matrices (not vectors or higher-rank tensors),
    # 2. their inner dimensions must agree,
    # 3. they must share a device, because the output is allocated on
    #    a.device and all three pointers are handed to a single launch.
    assert a.ndim == b.ndim == 2, "input not matrix."
    assert a.shape[1] == b.shape[0], "incompatible matmul dimensions."
    assert a.device == b.device, "inputs must be on the same device."
    a, b = a.to(torch.float16), b.to(torch.float16)
    (M, K), (_, N) = a.shape, b.shape

    # Allocate the output; the kernel's masked store writes every in-bounds element.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    # 1-D launch grid: one program per output tile. The tile sizes come from
    # whichever autotune configuration is selected, hence the callable.
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), )
    _matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
def test_matmul_kernel(size: tuple, atol=1e-2, rtol=1e-1, device=DEVICE):
    """Check ``matmul`` against ``torch.matmul`` on random float16 inputs.

    Args:
        size: ``(rows, inner)`` shape of the left operand. The right operand
            is created with the transposed shape ``(inner, rows)`` so the
            product is defined for any ``size``, square or not.
        atol: absolute tolerance passed to ``torch.testing.assert_close``.
        rtol: relative tolerance passed to ``torch.testing.assert_close``.
        device: device on which the inputs are allocated.

    Raises:
        AssertionError: if ``size`` is not a 2-tuple or the results differ
            by more than the given tolerances.
    """
    # Fixed seed so a failure is reproducible.
    torch.manual_seed(0)
    assert isinstance(size, tuple) and len(size) == 2
    rows, inner = size
    # Create the inputs from the requested size on the requested device.
    a = torch.randn((rows, inner), device=device, dtype=torch.float16)
    b = torch.randn((inner, rows), device=device, dtype=torch.float16)
    # Run the Triton kernel and the PyTorch reference implementation.
    c_tri = matmul(a, b)
    c_ref = torch.matmul(a, b)
    # Compare.
    torch.testing.assert_close(c_tri, c_ref, atol=atol, rtol=rtol)
    print("PASSED")
# Benchmark sweep definition consumed by `triton.testing.perf_report`.
# Listing several names in `x_names` grows M, N and K together, so each
# x value describes one square problem; the sweep runs from 256 to 4096
# in steps of 128.
configs = [
    triton.testing.Benchmark(
        plot_name="matmul-performance",
        ylabel="TFLOPS",
        x_names=["M", "N", "K"],
        x_vals=list(range(256, 4097, 128)),
        # One plotted line per provider.
        line_arg="provider",
        line_vals=["torch", "triton"],
        line_names=["PyTorch", "Triton"],
        styles=[("green", "-"), ("blue", "-")],
        args={},
    )
]
@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider):
    """Time one (M, K) @ (K, N) float16 matmul and report its throughput.

    Args:
        M, N, K: problem dimensions.
        provider: ``'torch'`` for ``torch.matmul`` or ``'triton'`` for the
            Triton-backed ``matmul`` defined in this file.

    Returns:
        Three TFLOPS figures, computed from the median, the 95th-percentile
        and the 5th-percentile run time respectively, i.e.
        (typical, lower bound, upper bound).

    Raises:
        ValueError: if ``provider`` is not one of the two known names.
    """
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)

    if provider == 'torch':
        run = lambda: torch.matmul(a, b)
    elif provider == 'triton':
        run = lambda: matmul(a, b)
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"unknown provider: {provider!r}")

    # Quantile order is median, 5th percentile (fastest), 95th percentile (slowest).
    quantiles = [0.5, 0.05, 0.95]
    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    def perf(ms):
        # A matmul performs 2 * M * N * K floating-point operations (one
        # multiply and one add per inner-product term). 1e-12 scales FLOP to
        # TFLOP and 1e-3 converts milliseconds to seconds.
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    # The slowest time gives the lowest throughput, hence max_ms before min_ms.
    return perf(ms), perf(max_ms), perf(min_ms)
if __name__ == "__main__":
    import sys

    # Always run the correctness check first.
    test_matmul_kernel(size=(1024, 1024))

    # The benchmark sweep only runs when "--benchmark" is the first
    # command-line argument.
    benchmark_requested = sys.argv[1:2] == ["--benchmark"]
    if benchmark_requested:
        benchmark.run(save_path='.', print_data=False)