Commit 69d425c

sijiac authored and facebook-github-bot committed
Prototype
Differential Revision: D61055780
1 parent a8ce4b5 commit 69d425c

File tree

3 files changed: +295 −0 lines changed


torchbenchmark/operators/fused_ffn/__init__.py

Whitespace-only changes.
@@ -0,0 +1,295 @@
from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            # B_T, H_D (8192), D (2048)
            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_M in [32]
        for BLOCK_N in [128, 256]
        for BLOCK_K in [128, 256]
        for num_stages in [2, 4, 8]
        for num_warps in [8]
    ],
    key=["B_T", "D", "H_D"],
)
@triton.jit
def fused_ffn_fwd(
    x_ptr,
    w13_ptr,
    w2_ptr,
    output_ptr,
    p_ptr,
    B_T,
    stride_xa,
    stride_xb,
    stride_w13a,
    stride_w13b,
    stride_w2a,
    stride_w2b,
    stride_oa,
    stride_ob,
    stride_pa,
    stride_pb,
    D: tl.constexpr,
    H_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Each program handles one BLOCK_M slice of rows and the full D output width.
    # Assumes B_T, H_D and D are multiples of the block sizes (no boundary checks).
    pid_m = tl.program_id(axis=0)
    dtype = x_ptr.dtype.element_ty

    X_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(B_T, D),
        strides=(stride_xa, stride_xb),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(B_T, D),
        strides=(stride_oa, stride_ob),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    for start_n in range(0, H_D, BLOCK_N):
        P_block_ptr = tl.make_block_ptr(
            base=p_ptr,
            shape=(B_T, H_D),
            strides=(stride_pa, stride_pb),
            offsets=(pid_m * BLOCK_M, start_n),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
        # w13 is [2*H_D, D]; read it through transposed strides so w1t/w3t are
        # (D, H_D) views: w1 rows live at [0, H_D), w3 rows at [H_D, 2*H_D).
        w1t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w3t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        # The w2 rows must track the current hidden slice [start_n, start_n + BLOCK_N).
        w2_bptr = tl.make_block_ptr(
            base=w2_ptr,
            shape=(H_D, D),
            strides=(stride_w2a, stride_w2b),
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K),
            order=(1, 0),
        )

        x_bptr = X_block_ptr
        o_bptr = O_block_ptr
        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # first GEMM: acc_1 = x @ w1.T and acc_3 = x @ w3.T, reduced over D in BLOCK_K steps
        w1t_bptr_inner = w1t_bptr
        w3t_bptr_inner = w3t_bptr
        w2_bptr_inner = w2_bptr
        for _ in range(0, D, BLOCK_K):
            x = tl.load(x_bptr)
            w1t = tl.load(w1t_bptr_inner)
            w3t = tl.load(w3t_bptr_inner)
            acc_1 = tl.dot(x, w1t, acc_1)
            acc_3 = tl.dot(x, w3t, acc_3)
            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
        # acc_1 = acc_1.to(dtype).to(tl.float32)
        # acc_3 = acc_3.to(dtype).to(tl.float32)
        # SiLU(acc_1) * acc_3, also stored as the intermediate p for numerics checks
        p = acc_1 * tl.sigmoid(acc_1) * acc_3
        p = p.to(dtype)
        tl.store(P_block_ptr, p)
        # second GEMM: accumulate p @ w2[start_n:start_n + BLOCK_N, :] into this
        # row block of the output, sweeping the D output columns in BLOCK_K chunks.
        # Requires the output buffer to be zero-initialized by the caller.
        for _ in range(0, D, BLOCK_K):
            w2 = tl.load(w2_bptr_inner)
            o = tl.load(o_bptr)
            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))
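
# For reference, the kernel above fuses the SwiGLU feed-forward
# output = (silu(x @ w1.T) * (x @ w3.T)) @ w2, with w1 and w3 packed row-wise
# into w13. A minimal eager sketch of the same math (the helper name below is
# illustrative only):
#
#     import torch.nn.functional as F
#
#     def _swiglu_reference(x, w13, w2):
#         h1, h3 = (x @ w13.T).chunk(2, dim=-1)  # two [B_T, H_D] halves
#         return (F.silu(h1) * h3) @ w2          # [B_T, D]
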

def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # x: [B_T, D]
    # w13: [H_D*2, D]
    # w2: [H_D, D]
    # output: [B_T, D], p (the intermediate silu(x @ w1.T) * (x @ w3.T)): [B_T, H_D]
    B_T, D = x.shape
    H_D_2, D = w13.shape
    H_D = w2.shape[0]
    assert H_D_2 == 2 * H_D, f"w13's first dim must be 2 * H_D, but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)

    # The kernel accumulates partial products into output, so it must start at zero.
    output = torch.zeros_like(x)
    p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)

    fused_ffn_fwd[grid](
        x,
        w13,
        w2,
        output,
        p,
        B_T,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2.stride(0),
        w2.stride(1),
        output.stride(0),
        output.stride(1),
        p.stride(0),
        p.stride(1),
        D,
        H_D,
    )

    return output, p
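
# Example usage (illustrative shapes, matching the benchmark defaults below):
#
#     x = torch.randn((4096, 2048), dtype=torch.bfloat16, device="cuda")
#     w13 = torch.randn((2 * 8192, 2048), dtype=torch.bfloat16, device="cuda")
#     w2 = torch.randn((8192, 2048), dtype=torch.bfloat16, device="cuda")
#     out, p = fused_ffn(x, w13, w2)  # out: [4096, 2048], p: [4096, 8192]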

@triton.jit
# pyre-fixme[3]: Return type must be annotated.
def _silu_mul_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    x1_ptr,
    x1_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    x2_ptr,
    x2_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,
    D: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row: y[b, :] = silu(x1[b, :]) * x2[b, :]
    b = tl.program_id(0).to(tl.int64)

    x1_start = x1_ptr + b * x1_stride
    x2_start = x2_ptr + b * x2_stride
    y_start = y_ptr + b * D

    for offset in range(0, D, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < D
        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
        yv = (x1v * tl.sigmoid(x1v) * x2v).to(tl.bfloat16)
        tl.store(y_start + cols, yv, mask=mask)


sigmoid = torch.nn.Sigmoid()


def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    assert x1.shape == x2.shape
    (B_T, D) = x1.shape
    out = torch.empty_like(x1)
    assert x1.stride(1) == x2.stride(1) == 1
    assert out.is_contiguous()
    grid = (B_T,)
    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
    return out
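
# Quick numeric sanity check for the standalone kernel (illustrative sketch,
# comparing against the PyTorch silu reference at the tolerance used below):
#
#     a = torch.randn((256, 1024), dtype=torch.bfloat16, device="cuda")
#     b = torch.randn_like(a)
#     ref = torch.nn.functional.silu(a) * b
#     assert torch.allclose(silu_mul(a, b), ref, atol=1e-2, rtol=0)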

def _ffn(x, w13, w2):
    p = x @ w13.T
    H_D_2, D = w13.shape
    H_D = H_D_2 // 2
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = silu_mul(p1, p2)  # B_T, H_D
    out = p_out @ w2
    return out, p_out


def numerics_check(shape):
    B_T, H_D, D = shape
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    triton_out, triton_p = fused_ffn(x, w13, w2)
    eager_out, eager_p = _ffn(x, w13, w2)

    print("P numeric check: ", torch.allclose(triton_p, eager_p, atol=1e-2, rtol=0))
    print("Output numeric check: ", torch.allclose(triton_out, eager_out, atol=1e-2, rtol=0))
    # print(triton_p[-1])
    # print(eager_p[-1])

def do_benchmark():

    D = 2048
    H_D = 8192

    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # Argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # Different possible values for `x_names`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`
            line_vals=["eager", "fused"],
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency (ms)",  # Label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        # breakpoint()
        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: _ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # B_T, H_D, D
    # numerics_check((16, 128, 128))
    # numerics_check((256, 8192, 2048))
    do_benchmark()

torchbenchmark/operators/fused_ffn/operator.py

Whitespace-only changes.
