-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path2D_Convolution_Triton.py
More file actions
50 lines (39 loc) · 1.33 KB
/
2D_Convolution_Triton.py
File metadata and controls
50 lines (39 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import triton
import triton.language as tl
@triton.jit
def conv2d_kernel(
    x_ptr, w_ptr, y_ptr,
    H, W, KH, KW,
    x_sy, x_sx, w_sy, w_sx, y_sy, y_sx,
    TILE_H: tl.constexpr, TILE_W: tl.constexpr
):
    """Compute one TILE_H x TILE_W tile of a "valid" 2D cross-correlation.

    For every output position (r, c) inside the tile this accumulates
    ``sum over kh, kw of w[kh, kw] * x[r + kh, c + kw]`` (the filter is not
    flipped) and stores the result to ``y[r, c]``.

    Args:
        x_ptr: Pointer to the input image, addressed as ``row * x_sy + col * x_sx``.
        w_ptr: Pointer to the filter, addressed as ``kh * w_sy + kw * w_sx``.
        y_ptr: Pointer to the output, addressed as ``row * y_sy + col * y_sx``.
        H, W: Input height and width.
        KH, KW: Filter height and width.
        x_sy, x_sx, w_sy, w_sx, y_sy, y_sx: Row/column strides used in the
            pointer arithmetic above (the host wrapper in this file passes
            the values returned by ``Tensor.stride()``).
        TILE_H, TILE_W: Compile-time tile shape handled by one program.

    Grid axis 0 walks tiles along output rows, grid axis 1 along output
    columns.
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    # Output coordinates covered by this program, shaped for 2D broadcasting.
    rows = (tile_row * TILE_H + tl.arange(0, TILE_H))[:, None]
    cols = (tile_col * TILE_W + tl.arange(0, TILE_W))[None, :]

    # The output is (H - KH + 1) x (W - KW + 1). Lanes past that edge are
    # masked off, which also keeps every rows+kh / cols+kw read inside x.
    rows_out = H - KH + 1
    cols_out = W - KW + 1
    in_bounds = (rows < rows_out) & (cols < cols_out)

    # Pointers to the top-left input element of each output position's window.
    window_ptrs = x_ptr + rows * x_sy + cols * x_sx

    # Accumulate in float32 regardless of the element type behind the pointers.
    total = tl.zeros((TILE_H, TILE_W), dtype=tl.float32)
    for kh in range(KH):
        # Advance once per filter row; the inner loop only moves along columns.
        x_row_ptrs = window_ptrs + kh * x_sy
        w_row_ptr = w_ptr + kh * w_sy
        for kw in range(KW):
            weight = tl.load(w_row_ptr + kw * w_sx)
            patch = tl.load(x_row_ptrs + kw * x_sx, mask=in_bounds, other=0.0)
            total += weight * patch

    tl.store(y_ptr + rows * y_sy + cols * y_sx, total, mask=in_bounds)
def conv2d_forward(x, w, y, H, W, KH, KW):
    """Launch ``conv2d_kernel`` to write the valid 2D cross-correlation into ``y``.

    Args:
        x: Input tensor holding ``H * W`` elements; reshaped to ``(H, W)``.
        w: Filter tensor holding ``KH * KW`` elements; reshaped to ``(KH, KW)``.
        y: Output tensor holding ``(H - KH + 1) * (W - KW + 1)`` elements.
            It is written in place, so it must be viewable with that 2D
            shape without copying (e.g. a contiguous buffer).
        H, W: Input height and width.
        KH, KW: Filter height and width.

    Returns:
        None. The result is stored into ``y``'s existing storage.

    Raises:
        RuntimeError: Raised by PyTorch if ``y`` cannot be viewed as
            ``(H - KH + 1, W - KW + 1)`` without a copy, or if any tensor's
            element count does not match the requested shape.

    NOTE(review): presumably all three tensors live on a device Triton can
    target; nothing here checks that -- confirm against callers.
    """
    out_h = H - KH + 1
    out_w = W - KW + 1

    # Inputs are only read, so a reshape that falls back to copying is fine.
    x = x.reshape(H, W)
    w = w.reshape(KH, KW)
    # The output must alias the caller's buffer. reshape() may silently hand
    # back a copy for incompatible strides, which would discard the kernel's
    # writes; view() either aliases the original storage or raises.
    y = y.view(out_h, out_w)

    x_sy, x_sx = x.stride()
    w_sy, w_sx = w.stride()
    y_sy, y_sx = y.stride()

    # One program per 8x8 output tile; partial edge tiles are masked in-kernel.
    TILE_H, TILE_W = 8, 8
    grid = (triton.cdiv(out_h, TILE_H), triton.cdiv(out_w, TILE_W))
    conv2d_kernel[grid](
        x, w, y, H, W, KH, KW,
        x_sy, x_sx, w_sy, w_sx, y_sy, y_sx,
        TILE_H, TILE_W
    )