
Commit 47b1442

Dhruva Kaushal authored and facebook-github-bot committed
Adding fp8_gemm_grouped to tritonbench
Summary: The benchmarking code lives in operator.py, which calls into the Triton kernel inside grouped_gemm. Here's how the output looks:

```
(group_size, M, N, K)       _triton-tflops    _triton-gbps
------------------------    --------------    ------------
(2, 4096, 4096, 4096)              1540.59         470.151
(2, 8192, 8192, 8192)              1635.58         249.57
(2, 16384, 16384, 16384)           1623.3          123.848
(4, 4096, 4096, 4096)              3083.73         658.756
(4, 8192, 8192, 8192)              3514.4          375.378
(4, 16384, 16384, 16384)           3535.54         188.818
average                            2488.86         344.42
```

Currently CUTLASS is disabled until D69544396 lands, so that the Triton and CUTLASS kernels have the same function signature (and we can use the benchmark operator's get_input_iter method). TODO: enable CUTLASS.

Reviewed By: manman-ren

Differential Revision: D70947086

fbshipit-source-id: 096ba7ba8c603c2dfd0cd1a7307722398ee48d90
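For context, a minimal sketch (not part of this commit) of how the fp8 TMA path in grouped_gemm could be timed directly, outside the tritonbench operator. The import path, the B layout ([N, K], as expected by group_gemm_tma_fn), and the plain 2*M*N*K FLOP count are illustrative assumptions and may not match operator.py's exact metric computation; an SM90+ GPU is assumed:

```
# Illustrative harness; not the tritonbench operator.
import torch
import triton

from grouped_gemm import group_gemm_tma_fn  # import path is an assumption


def bench_one(group_size, M, N, K):
    # A is [M, K]; B is stored as [N, K], matching group_gemm_tma_fn below.
    group_A = [
        torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
        for _ in range(group_size)
    ]
    group_B = [
        torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
        for _ in range(group_size)
    ]
    ms = triton.testing.do_bench(lambda: group_gemm_tma_fn(group_A, group_B))
    # 2*M*N*K flops per GEMM, summed over the group; converted to TFLOP/s.
    return 2.0 * M * N * K * group_size / (ms * 1e-3) / 1e12


print(bench_one(2, 4096, 4096, 4096))
```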
1 parent 352be40 commit 47b1442

File tree

3 files changed: +1073 -0 lines changed

@@ -0,0 +1 @@
from .operator import Operator
@@ -0,0 +1,395 @@
"""
Group GEMM
============================
This group gemm kernel launches a fixed number of CTAs to compute a group
of gemms. The scheduling is static and we do it on device.
"""

# Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Optional

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_current_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 84,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 128,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 84,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 128,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "NUM_SM": num_sms(),
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "NUM_SM": num_sms(),
            }
        ),
    ],
    key=["group_size"],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
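    # Persistent scheduling: this program starts at tile index tl.program_id(0)
    # and strides by NUM_SM over the concatenation of all tiles of all GEMMs in
    # the group. For example, with NUM_SM = 2 and a group whose problems have 3
    # and 2 tiles, program 0 handles tiles 0, 2, 4 and program 1 handles 1, 3.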
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tile for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # assumes full tile for now
            tl.store(c_ptrs, c)

            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_fn(group_A, group_B):
    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    # we use a fixed number of CTA, and it's auto-tunable
    grid = lambda META: (META["NUM_SM"],)
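    # NUM_SM and the block sizes are constexpr arguments chosen by the autotuner
    # (key=["group_size"]), so only the runtime arguments are passed at launch.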
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
    )

    return group_C
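
# Example usage of the non-TMA fp16 path above (illustrative only; not part of
# the benchmark harness in operator.py):
#
#   group_A = [torch.randn(M, K, device="cuda", dtype=torch.float16) for _ in range(4)]
#   group_B = [torch.randn(K, N, device="cuda", dtype=torch.float16) for _ in range(4)]
#   group_C = group_gemm_fn(group_A, group_B)  # list of [M, N] fp16 outputs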


tma_configs = [
    triton.Config(
        {"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK},
        num_stages=s,
        num_warps=w,
    )
    for BM in [128]
    for BN in [128, 256]
    for BK in [64, 128]
    for s in ([3, 4])
    for w in [4, 8]
]


@triton.autotune(
    tma_configs,
    key=["group_a_ptrs", "group_b_ptrs", "group_c_ptrs", "group_size"],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.float16
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
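            # Build the TMA descriptors once per GEMM under this guard, rather
            # than once per tile inside the while loop below; the guard also
            # skips problems whose tiles are not assigned to this program.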
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))

            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )

            b_desc = tl._experimental_make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl._experimental_make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )

            # iterate through the tiles in the current gemm problem
            while (
                tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
            ):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles

                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
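                    # b is a [BLOCK_SIZE_N, BLOCK_SIZE_K] tile because B is stored
                    # as [N, K]; transposing it in the dot computes A @ B^T.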
                    accumulator += tl.dot(a, b.T)

                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N

                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)

                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_tma_fn(group_A, group_B):
    assert supports_tma()

    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[1]
        M, K = A.shape
        N, K = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    # we use a fixed number of CTA, and it's auto-tunable

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_tma_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        FP8=torch.float8_e4m3fn == group_A[0].dtype,
        NUM_SM=num_sms(),
    )
    return group_C

0 commit comments