
Commit 1118c9d

sijiac authored and facebook-github-bot committed
Prototype (#2486)
Summary: Pull Request resolved: #2486

Differential Revision: D61055780
1 parent dde8528 commit 1118c9d

File tree

3 files changed (+449, -0 lines)

@@ -0,0 +1 @@
from .operator import Operator
@@ -0,0 +1,328 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import ast
import copy
import functools
import linecache
import os
import sys
import tempfile
from typing import Any, Dict, List

import torch

import triton
import triton.language as tl


def get_cuda_autotune_config():
    return [
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=2, num_warps=2
        ),
        # triton.Config(
        #     {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4
        # ),
        # triton.Config(
        #     {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=4
        # ),
        # triton.Config(
        #     {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=4
        # ),
        # triton.Config(
        #     {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=4
        # ),
        # triton.Config(
        #     {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=4
        # ),
        # triton.Config(
        #     {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=5, num_warps=2
        # ),
        # triton.Config(
        #     {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=5, num_warps=2
        # ),
    ]


def get_autotune_config():
    return get_cuda_autotune_config()


@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "D", "H_D"],
)
@triton.jit
def fused_ffn_kernel(
    X_ptr,
    W13_ptr,
    W2_ptr,
    Y_ptr,
    P_out_ptr,  # Output for intermediate results
    M,
    D,
    H_D,  # Note: P is not needed as a parameter since P == D
    stride_xm,
    stride_xd,
    stride_w13a,
    stride_w13b,
    stride_w2n,
    stride_w2d,  # Changed from stride_w2p to stride_w2d
    stride_ym,
    stride_yd,  # Changed from stride_yp to stride_yd
    stride_poutm,
    stride_poutn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,  # This will be used for both D and P dimensions
    BLOCK_K_D: tl.constexpr,  # This will be used for D dimension only
):
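    """Compute Y = (silu(X @ W1.T) * (X @ W3.T)) @ W2 and store the gated
    intermediate in P_out.

    W13 packs W1 and W3 row-wise as [2 * H_D, D]; the block pointers below read
    it through transposed (D, H_D) views. Each program owns one BLOCK_M row tile
    of X, sweeps H_D in BLOCK_N chunks, and for each chunk sweeps D in BLOCK_K
    chunks to build the intermediate. The second GEMM accumulates into a
    (BLOCK_M, BLOCK_K_D) accumulator, so BLOCK_K_D must cover the full output
    width D (the host wrapper passes BLOCK_K_D = D).
    """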
    # Program IDs for M dimension
    pid_m = tl.program_id(0)

    # Offsets for M
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < M

    # Initialize accumulator with float32 precision
    acc = tl.zeros((BLOCK_M, BLOCK_K_D), dtype=tl.float32)

    # Loop over H_D in BLOCK_N chunks
    for start_n in range(0, H_D, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < H_D

        # Initialize partial results
        p1_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        p2_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        # Block pointers for W13 (for p1 and p2)
        w1t_bptr = tl.make_block_ptr(
            base=W13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(1, 0),
        )
        w3t_bptr = tl.make_block_ptr(
            base=W13_ptr,
            shape=(D, H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(1, 0),
        )

        # Loop over K (which is equal to D) in BLOCK_K chunks
        for k in range(0, D, BLOCK_K):
            offs_k = k + tl.arange(0, BLOCK_K)
            mask_k = offs_k < D

            # Load X block
            x_bptr = tl.make_block_ptr(
                base=X_ptr,
                shape=(M, D),
                strides=(stride_xm, stride_xd),
                offsets=(pid_m * BLOCK_M, k),
                block_shape=(BLOCK_M, BLOCK_K),
                order=(1, 0),
            )
            X_block = tl.load(x_bptr, boundary_check=(0, 1), padding_option="zero")
            # X_block = tl.where(mask_m[:, None] & mask_k[None, :], X_block, 0.0).to(
            #     tl.float16
            # )

            # Load W1 and W3 blocks
            W1_block = tl.load(w1t_bptr)
            W3_block = tl.load(w3t_bptr)

            # Perform GEMM operations
            p1_block += tl.dot(X_block, W1_block)
            p2_block += tl.dot(X_block, W3_block)

            # Advance the block pointers
            w1t_bptr = tl.advance(w1t_bptr, (BLOCK_K, 0))
            w3t_bptr = tl.advance(w3t_bptr, (BLOCK_K, 0))

        # Apply SiLU activation to p1 and multiply with p2
        p_out_block = p1_block * tl.sigmoid(p1_block) * p2_block
        # p_out_block = tl.where(mask_m[:, None] & mask_n[None, :], p_out_block, 0.0)

        # Store P_out
        P_out_offs = P_out_ptr + (
            offs_m[:, None] * stride_poutm + offs_n[None, :] * stride_poutn
        )
        tl.store(
            P_out_offs,
            p_out_block.to(tl.float16),
            mask=mask_m[:, None] & mask_n[None, :],
        )

        w2_bptr = tl.make_block_ptr(
            base=W2_ptr,
            shape=(H_D, D),
            strides=(stride_w2n, stride_w2d),
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K_D),
            order=(0, 1),
        )
        W2_block = tl.load(w2_bptr, boundary_check=(0, 1), padding_option="zero")

        # Perform the second GEMM
        acc += tl.dot(p_out_block.to(tl.float16), W2_block)

    offs_d = tl.arange(0, BLOCK_K_D)
    mask_d = offs_d < D
    y_offs = Y_ptr + offs_m[:, None] * stride_ym + offs_d[None, :] * stride_yd
    tl.store(y_offs, acc.to(tl.float16), mask=mask_m[:, None] & mask_d[None, :])


def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, has_p: bool = False
):
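    """Launch fused_ffn_kernel on x [B_T, D], w13 [2*H_D, D], w2 [P, H_D] with P == D.

    Returns (output [B_T, P], p_out [B_T, H_D]) when has_p is True, otherwise
    (output, None). w2 is transposed to [H_D, P] and made contiguous before the
    launch so the kernel can read it as the right-hand side of the second GEMM.
    """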
    # x: [B_T, D]
    # w13: [H_D*2, D]
    # D = K
    # out1: [B_T, H_D]
    # w2: [P, H_D]
    # P = K
    # output: [B_T, P]
    B_T, D = x.shape
    H_D_2, D = w13.shape
    P, H_D = w2.shape
    assert D == P, f"D and P must be equal but got {D=} and {P=}"
    assert H_D_2 == 2 * H_D, f"H_D_2 must be twice H_D but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)  # triton.cdiv(P, META["BLOCK_P"]))

    output = torch.empty((B_T, P), dtype=x.dtype, device=x.device)
    if has_p:
        p_out = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)
    else:
        p_out = torch.empty(1, dtype=x.dtype, device=x.device)  # Dummy tensor

    w2_t = w2.t().contiguous()

    BLOCK_K_D = D

    fused_ffn_kernel[grid](
        x,
        w13,
        w2_t,
        output,
        p_out,
        B_T,
        D,
        H_D,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2_t.stride(0),
        w2_t.stride(1),
        output.stride(0),
        output.stride(1),
        p_out.stride(0) if has_p else 0,
        p_out.stride(1) if has_p else 0,
        BLOCK_K_D=BLOCK_K_D,
    )

    return output, p_out if has_p else None
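
# Example usage (illustrative sketch; shapes mirror the numeric check below):
#   x = torch.randn((B_T, D), dtype=torch.float16, device="cuda")
#   w13 = torch.randn((2 * H_D, D), dtype=torch.float16, device="cuda")
#   w2 = torch.randn((D, H_D), dtype=torch.float16, device="cuda")
#   y, p = fused_ffn(x, w13, w2, has_p=True)  # y: [B_T, D], p: [B_T, H_D]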


def eager_ffn(x, w13, w2):
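    """Eager PyTorch reference: (silu(x @ w1.T) * (x @ w3.T)) @ w2.T, where w13 stacks [w1; w3].

    Returns (out, p_out) so both the final output and the gated intermediate can
    be compared against the fused kernel.
    """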
    p = torch.matmul(x, w13.t())
    H_D_2, D = w13.shape
    H_D = H_D_2 // 2
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = p1 * torch.sigmoid(p1) * p2
    out = torch.matmul(p_out, w2.t())
    return out, p_out


def numerics_check(shape):
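    """Compare fused_ffn against eager_ffn on random fp16 CUDA inputs of shape (B_T, H_D, D)."""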
    B_T, H_D, D = shape
    print(f"Running numeric check for {shape}")
    x = torch.randn((B_T, D), dtype=torch.float16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.float16, device="cuda") * 0.1
    w2 = torch.randn((D, H_D), dtype=torch.float16, device="cuda") * 0.1
    triton_out, triton_p = fused_ffn(x, w13, w2, has_p=True)
    eager_out, eager_p = eager_ffn(x, w13, w2)

    if not torch.allclose(triton_p, eager_p, atol=1e-2, rtol=1e-2):
        print("P numeric check failed")
        print(f"triton output: {triton_p.flatten()[0:10]}")
        print(f"eager output: {eager_p.flatten()[0:10]}")
    else:
        print("P numeric check passed")
    if not torch.allclose(triton_out, eager_out, atol=1e-2, rtol=1e-2):
        print("Y numeric check failed")
        print(f"triton output: {triton_out.flatten()[0:10]}")
        print(f"eager output: {eager_out.flatten()[0:10]}")
    else:
        print("Y numeric check passed")

    torch.testing.assert_close(triton_out, eager_out, atol=1e-2, rtol=1e-2)


def do_benchmark():
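    """Benchmark eager_ffn vs. fused_ffn latency over a sweep of B_T values using triton.testing."""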

    D = 2048
    H_D = 8192

    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # Argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(5325, 4096)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # Different possible values for `x_name`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`
            # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
            line_vals=["eager", "fused"],
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency(ms)",  # Label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        # breakpoint()
        x = torch.randn((B_T, D), dtype=torch.float16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.float16, device="cuda")
        w2 = torch.randn((D, H_D), dtype=torch.float16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: eager_ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)
323+
if __name__ == "__main__":
324+
# B_T, H_D, D
325+
torch.manual_seed(0)
326+
nunerics_check((1024, 1024, 128))
327+
328+
# do_benchmark()

0 commit comments
