
Commit ebb212c

sijiac authored and facebook-github-bot committed
Prototype (#2486)
Summary: Pull Request resolved: #2486. Differential Revision: D61055780
1 parent a8ce4b5 commit ebb212c

3 files changed: +427 -0 lines changed
@@ -0,0 +1 @@
from .operator import Operator

@@ -0,0 +1,306 @@
import torch
import triton
import triton.language as tl


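# fused_ffn_fwd fuses a SwiGLU-style feed-forward block into a single kernel:
#     out = (silu(x @ w1.T) * (x @ w3.T)) @ w2
# where w13 packs w1 and w3 as a single [H_D * 2, D] tensor and w2 is [H_D, D].
# Each program owns one BLOCK_M slab of rows, walks the hidden dimension H_D in
# BLOCK_N tiles, and accumulates each tile's contribution into `output`.
# Note: the block-pointer loads below carry no boundary checks, so this
# description assumes B_T, D, and H_D are multiples of BLOCK_M, BLOCK_K, and
# BLOCK_N respectively.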
@triton.autotune(
    configs=[
        triton.Config(
            # B_T, H_D (8192), D (2048)
            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_M in [64]
        for BLOCK_N in [128]
        for BLOCK_K in [128, 256]
        for num_stages in [2]
        for num_warps in [8]
    ],
    key=["B_T", "D", "H_D"],
)
@triton.jit
def fused_ffn_fwd(
    x_ptr,
    w13_ptr,
    w2_ptr,
    output_ptr,
    p_ptr,
    B_T,
    stride_xa,
    stride_xb,
    stride_w13a,
    stride_w13b,
    stride_w2a,
    stride_w2b,
    stride_oa,
    stride_ob,
    stride_pa,
    stride_pb,
    HAS_P: tl.constexpr,
    D: tl.constexpr,
    H_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    dtype = x_ptr.dtype.element_ty

    X_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(B_T, D),
        strides=(stride_xa, stride_xb),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(B_T, D),
        strides=(stride_oa, stride_ob),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    for start_n in range(0, H_D, BLOCK_N):
        if HAS_P:
            P_block_ptr = tl.make_block_ptr(
                base=p_ptr,
                shape=(B_T, H_D),
                strides=(stride_pa, stride_pb),
                offsets=(pid_m * BLOCK_M, start_n),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
        else:
            P_block_ptr = None

        w1t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w3t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w2_bptr = tl.make_block_ptr(
            base=w2_ptr,
            shape=(H_D, D),
            strides=(stride_w2a, stride_w2b),
            # rows of w2 must line up with the hidden slice [start_n, start_n + BLOCK_N)
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K),
            order=(1, 0),
        )

        x_bptr = X_block_ptr
        o_bptr = O_block_ptr
        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # first GEMM: accumulate x @ w1.T and x @ w3.T over the D (reduction) dimension
        w1t_bptr_inner = w1t_bptr
        w3t_bptr_inner = w3t_bptr
        w2_bptr_inner = w2_bptr
        for _ in range(0, D, BLOCK_K):
            x = tl.load(x_bptr)
            w1t = tl.load(w1t_bptr_inner)
            w3t = tl.load(w3t_bptr_inner)
            acc_1 = tl.dot(x, w1t, acc_1)
            acc_3 = tl.dot(x, w3t, acc_3)
            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
        # acc_1 = acc_1.to(dtype).to(tl.float32)
        # acc_3 = acc_3.to(dtype).to(tl.float32)
        p = acc_1 * tl.sigmoid(acc_1) * acc_3
        p = p.to(dtype)
        if HAS_P:
            tl.store(P_block_ptr, p)
        # second GEMM: output += p @ w2[start_n : start_n + BLOCK_N, :], tiled over the D output columns
        for _ in range(0, D, BLOCK_K):
            w2 = tl.load(w2_bptr_inner)
            o = tl.load(o_bptr)
            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))


def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, has_p: bool = False
):
    # x: [B_T, D]
    # w13: [H_D*2, D]
    # w2: [H_D, D]
    # output: [B_T, D]
    B_T, D = x.shape
    H_D_2, D = w13.shape
    H_D = w2.shape[0]
    assert H_D_2 == 2 * H_D, f"H_D_2 must be twice H_D but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)

    # the kernel accumulates partial products into output across H_D tiles,
    # so output must start zero-initialized
    output = torch.zeros_like(x)
    if has_p:
        p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)
    else:
        p = None

    fused_ffn_fwd[grid](
        x,
        w13,
        w2,
        output,
        p,
        B_T,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2.stride(0),
        w2.stride(1),
        output.stride(0),
        output.stride(1),
        p.stride(0) if has_p else 0,
        p.stride(1) if has_p else 0,
        has_p,
        D,
        H_D,
    )

    return output, p
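
# Example usage (illustrative sketch; the shapes mirror the benchmark's largest config):
#     x = torch.randn((4096, 2048), dtype=torch.bfloat16, device="cuda")
#     w13 = torch.randn((2 * 8192, 2048), dtype=torch.bfloat16, device="cuda")
#     w2 = torch.randn((8192, 2048), dtype=torch.bfloat16, device="cuda")
#     out, _ = fused_ffn(x, w13, w2)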


@triton.jit
# pyre-fixme[3]: Return type must be annotated.
def _silu_mul_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    x1_ptr,
    x1_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    x2_ptr,
    x2_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,
    D: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # one program per row: y[b, :] = silu(x1[b, :]) * x2[b, :]
    b = tl.program_id(0).to(tl.int64)

    x1_start = x1_ptr + b * x1_stride
    x2_start = x2_ptr + b * x2_stride
    y_start = y_ptr + b * D

    for offset in range(0, D, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < D
        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
        yv = (x1v * tl.sigmoid(x1v) * x2v).to(tl.bfloat16)
        tl.store(y_start + cols, yv, mask=mask)


sigmoid = torch.nn.Sigmoid()


def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    assert x1.shape == x2.shape
    (B_T, D) = x1.shape
    out = torch.empty_like(x1)
    assert x1.stride(1) == x2.stride(1) == 1
    assert out.is_contiguous()
    grid = (B_T,)
    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
    return out


def eager_ffn(x, w13, w2):
    # PyTorch reference path: split w13 into w1 and w3, apply silu-gating, then project with w2
    p = x @ w13.T
    H_D_2, D = w13.shape
    H_D = H_D_2 // 2
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = silu_mul(p1, p2)  # B_T, H_D
    out = p_out @ w2
    return out, p_out
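
# Both paths assume the packed layout w13 = [w1; w3]: rows [0, H_D) hold w1 and rows
# [H_D, 2 * H_D) hold w3, matching the (0, start_n) and (0, H_D + start_n) block-pointer
# offsets used in fused_ffn_fwd above.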


def numerics_check(shape):
    B_T, H_D, D = shape
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    triton_out, triton_p = fused_ffn(x, w13, w2, has_p=True)
    eager_out, eager_p = eager_ffn(x, w13, w2)

    print("P numeric check: ", torch.allclose(triton_p, eager_p, atol=1e-2, rtol=1e-2))
    print("Out numeric check: ", torch.allclose(triton_out, eager_out, atol=1e-2, rtol=1e-2))
    # print("P numeric check: ", torch.allclose(eager_p, ref_p, atol=1e-2, rtol=0))
    # print(triton_p[-1])
    # print(eager_p[-1])
    # print(ref_p[-1])


def do_benchmark():

    D = 2048
    H_D = 8192

    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # Argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # Different possible values for `x_name`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`
            # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
            line_vals=["eager", "fused"],
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency (ms)",  # Label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        # breakpoint()
        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: eager_ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # B_T, H_D, D
    numerics_check((64, 128, 128))
    # numerics_check((256, 8192, 2048))
    # do_benchmark()
