"""
Group GEMM
============================
This group GEMM kernel launches a fixed number of CTAs to compute a group
of GEMMs. The scheduling is static and done on the device.
"""

# Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Optional

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_current_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 84,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 128,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 84,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 128,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "NUM_SM": num_sms(),
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "NUM_SM": num_sms(),
            }
        ),
    ],
    key=["group_size"],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensors of matrix pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SMs
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tile for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # assumes full tile for now
            tl.store(c_ptrs, c)

            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles

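# Illustrative sketch (not part of the kernel above): the static, on-device
# schedule amounts to each of the NUM_SM persistent CTAs visiting tile indices
# pid, pid + NUM_SM, pid + 2 * NUM_SM, ... over the concatenated tile space of
# all GEMMs in the group. The helper below mirrors that mapping in plain
# Python; its name and existence are ours, purely for exposition.
def _tiles_for_cta(pid, num_sm, tiles_per_gemm):
    # tiles_per_gemm: list of num_m_tiles * num_n_tiles for each GEMM
    total_tiles = sum(tiles_per_gemm)
    return list(range(pid, total_tiles, num_sm))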

def group_gemm_fn(group_A, group_B):
    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    # we use a fixed number of CTAs, and it's auto-tunable
    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
    )

    return group_C

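# A minimal usage sketch, not part of the original tutorial: build a small
# group of FP16 problems and check the grouped kernel against torch.matmul.
# The shapes are illustrative and chosen as multiples of the block sizes,
# since the kernel above assumes full tiles (no masking).
def _check_group_gemm():
    torch.manual_seed(0)
    problem_sizes = [(512, 256, 128), (256, 512, 64), (128, 128, 256)]  # (M, N, K), illustrative
    group_A = [torch.rand((m, k), device=DEVICE, dtype=torch.float16) for m, _, k in problem_sizes]
    group_B = [torch.rand((k, n), device=DEVICE, dtype=torch.float16) for _, n, k in problem_sizes]
    group_C = group_gemm_fn(group_A, group_B)
    for a, b, c in zip(group_A, group_B, group_C):
        torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-2, rtol=1e-2)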

tma_configs = [
    triton.Config(
        {"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK},
        num_stages=s,
        num_warps=w,
    )
    for BM in [128]
    for BN in [128, 256]
    for BK in [64, 128]
    for s in ([3, 4])
    for w in [4, 8]
]


@triton.autotune(
    tma_configs,
    key=["group_a_ptrs", "group_b_ptrs", "group_c_ptrs", "group_size"],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensors of matrix pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SMs
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.float16
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))

            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )

            b_desc = tl._experimental_make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl._experimental_make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )

            # iterate through the tiles in the current gemm problem
            while (
                tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
            ):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles

                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
                    accumulator += tl.dot(a, b.T)

                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N

                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)

                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_tma_fn(group_A, group_B):
    assert supports_tma()

    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[1]
        M, K = A.shape
        N, K = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    # we use a fixed number of CTAs; here NUM_SM is passed explicitly below rather than auto-tuned

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_tma_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        FP8=torch.float8_e4m3fn == group_A[0].dtype,
        NUM_SM=num_sms(),
    )
    return group_C
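

# A minimal usage sketch, not part of the original tutorial: the TMA path
# expects each B in (N, K) layout, i.e. already transposed, so the reference
# result is A @ B.T. Shapes are illustrative; this requires TMA support
# (Hopper or newer, see supports_tma()).
def _check_group_gemm_tma():
    if not supports_tma():
        return
    torch.manual_seed(0)
    problem_sizes = [(512, 256, 128), (256, 512, 128)]  # (M, N, K), illustrative
    group_A = [torch.rand((m, k), device=DEVICE, dtype=torch.float16) for m, _, k in problem_sizes]
    group_B = [torch.rand((n, k), device=DEVICE, dtype=torch.float16) for _, n, k in problem_sizes]
    group_C = group_gemm_tma_fn(group_A, group_B)
    for a, b, c in zip(group_A, group_B, group_C):
        torch.testing.assert_close(c, torch.matmul(a, b.T), atol=1e-2, rtol=1e-2)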