from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            # block sizes for (B_T, H_D, D); search space tuned around H_D=8192, D=2048
            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_M in [32]
        for BLOCK_N in [128, 256]
        for BLOCK_K in [128, 256]
        for num_stages in [2, 4, 8]
        for num_warps in [8]
    ],
    key=["B_T", "D", "H_D"],
    # the kernel accumulates into output_ptr, so zero it between autotuning runs
    reset_to_zero=["output_ptr"],
)
@triton.jit
def fused_ffn_fwd(
    x_ptr,  # [B_T, D] activations
    w13_ptr,  # [2 * H_D, D] w1 stacked on top of w3
    w2_ptr,  # [H_D, D] down projection
    output_ptr,  # [B_T, D] result, must be zero-initialized
    p_ptr,  # [B_T, H_D] intermediate silu(x @ w1.T) * (x @ w3.T)
    B_T,
    stride_xa,
    stride_xb,
    stride_w13a,
    stride_w13b,
    stride_w2a,
    stride_w2b,
    stride_oa,
    stride_ob,
    stride_pa,
    stride_pb,
    D: tl.constexpr,
    H_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    dtype = x_ptr.dtype.element_ty

    # each program owns one BLOCK_M tile of rows; no boundary checks are used,
    # so B_T, D and H_D are assumed to be multiples of the block sizes
    X_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(B_T, D),
        strides=(stride_xa, stride_xb),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(B_T, D),
        strides=(stride_oa, stride_ob),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    for start_n in range(0, H_D, BLOCK_N):
        P_block_ptr = tl.make_block_ptr(
            base=p_ptr,
            shape=(B_T, H_D),
            strides=(stride_pa, stride_pb),
            offsets=(pid_m * BLOCK_M, start_n),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
        # w13.T has logical shape (D, 2 * H_D): w1 columns live in [0, H_D),
        # w3 columns in [H_D, 2 * H_D)
        w1t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, 2 * H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w3t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, 2 * H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        # rows [start_n, start_n + BLOCK_N) of w2 multiply this slice of p
        w2_bptr = tl.make_block_ptr(
            base=w2_ptr,
            shape=(H_D, D),
            strides=(stride_w2a, stride_w2b),
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K),
            order=(1, 0),
        )

        x_bptr = X_block_ptr
        o_bptr = O_block_ptr
        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        w1t_bptr_inner = w1t_bptr
        w3t_bptr_inner = w3t_bptr
        w2_bptr_inner = w2_bptr
        # first GEMM: accumulate x @ w1.T and x @ w3.T over the D dimension
        for _ in range(0, D, BLOCK_K):
            x = tl.load(x_bptr)
            w1t = tl.load(w1t_bptr_inner)
            w3t = tl.load(w3t_bptr_inner)
            acc_1 = tl.dot(x, w1t, acc_1)
            acc_3 = tl.dot(x, w3t, acc_3)
            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
        # SwiGLU activation, computed in fp32 and downcast before the store
        p = acc_1 * tl.sigmoid(acc_1) * acc_3
        p = p.to(dtype)
        tl.store(P_block_ptr, p)
        # second GEMM: accumulate p @ w2[start_n:start_n + BLOCK_N, :] into the
        # output, one BLOCK_K-wide column tile at a time; the output must start
        # at zero since every start_n iteration adds its contribution
        for _ in range(0, D, BLOCK_K):
            w2 = tl.load(w2_bptr_inner)
            o = tl.load(o_bptr)
            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))


def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # x: [B_T, D]
    # w13: [H_D * 2, D]
    # w2: [H_D, D]
    # returns output: [B_T, D] and the intermediate p: [B_T, H_D]
    B_T, D = x.shape
    H_D_2, _ = w13.shape
    H_D = w2.shape[0]
    assert H_D_2 == 2 * H_D, f"H_D_2 must be twice H_D but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)

    # the kernel accumulates into the output across BLOCK_N tiles, so it must
    # start out zeroed
    output = torch.zeros_like(x)
    p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)

    fused_ffn_fwd[grid](
        x,
        w13,
        w2,
        output,
        p,
        B_T,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2.stride(0),
        w2.stride(1),
        output.stride(0),
        output.stride(1),
        p.stride(0),
        p.stride(1),
        D,
        H_D,
    )

    return output, p


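# Hedged usage sketch (not part of the original commit): shows the intended
# call pattern for fused_ffn on illustrative shapes. It assumes a CUDA device
# with bf16 support and that B_T, H_D and D are multiples of the autotuned
# block sizes, since the kernel skips boundary checks.
def example_fused_ffn_usage() -> Tuple[torch.Tensor, torch.Tensor]:
    B_T, H_D, D = 1024, 8192, 2048
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    out, p = fused_ffn(x, w13, w2)  # out: [B_T, D], p: [B_T, H_D]
    return out, p

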
@triton.jit
# pyre-fixme[3]: Return type must be annotated.
def _silu_mul_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    x1_ptr,
    x1_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    x2_ptr,
    x2_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,
    D: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    b = tl.program_id(0).to(tl.int64)

    # one program per row; each iteration handles BLOCK_SIZE columns
    x1_start = x1_ptr + b * x1_stride
    x2_start = x2_ptr + b * x2_stride
    y_start = y_ptr + b * D

    for offset in range(0, D, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < D
        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
        # compute silu(x1) * x2 in fp32, then cast to the output dtype
        yv = (x1v * tl.sigmoid(x1v) * x2v).to(y_ptr.dtype.element_ty)
        tl.store(y_start + cols, yv, mask=mask)


sigmoid = torch.nn.Sigmoid()


def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    assert x1.shape == x2.shape
    (B_T, D) = x1.shape
    out = torch.empty_like(x1)
    assert x1.stride(1) == x2.stride(1) == 1
    assert out.is_contiguous()
    grid = (B_T,)
    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
    return out

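
# Hedged sanity-check sketch (not part of the original commit): silu_mul is
# expected to match a plain PyTorch SiLU gate computed in fp32. The helper
# name and shapes below are illustrative assumptions.
def _silu_mul_reference_check() -> bool:
    a = torch.randn((64, 256), dtype=torch.bfloat16, device="cuda")
    b = torch.randn((64, 256), dtype=torch.bfloat16, device="cuda")
    ref = (torch.nn.functional.silu(a.float()) * b.float()).to(torch.bfloat16)
    return torch.allclose(silu_mul(a, b), ref, atol=1e-2, rtol=0)
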

def _ffn(x, w13, w2):
    # eager SwiGLU reference: uses the standalone silu_mul kernel for the gate
    # and also returns a pure-PyTorch p for numerics comparison
    p = x @ w13.T
    H_D_2, _ = w13.shape
    H_D = H_D_2 // 2
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = silu_mul(p1, p2)  # B_T, H_D
    ref_p = p1 * sigmoid(p1) * p2  # B_T, H_D, pure PyTorch reference
    out = p_out @ w2
    return out, p_out, ref_p


def numerics_check(shape):
    B_T, H_D, D = shape
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    triton_out, triton_p = fused_ffn(x, w13, w2)
    eager_out, eager_p, ref_p = _ffn(x, w13, w2)

    print(
        "P numeric check (triton vs eager):",
        torch.allclose(triton_p, eager_p, atol=1e-2, rtol=0),
    )
    print(
        "P numeric check (eager vs reference):",
        torch.allclose(eager_p, ref_p, atol=1e-2, rtol=0),
    )
    print(
        "output numeric check (triton vs eager):",
        torch.allclose(triton_out, eager_out, atol=1e-2, rtol=0),
    )


def do_benchmark():
    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # (B_T, H_D, D) shape combinations to sweep
            line_arg="provider",  # argument name whose value corresponds to a different line in the plot
            line_vals=["eager", "fused"],
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency (ms)",  # label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: _ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # shapes are (B_T, H_D, D)
    # numerics_check((16, 128, 128))
    # numerics_check((256, 8192, 2048))
    do_benchmark()