// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm.h"
#include "mlx/dtype_utils.h"

#include <cute/layout.hpp>
#include <cute/tensor.hpp>

// clang-format off

namespace cute {

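// Fallback MMA atom: a single-thread 1x1x1 FMA that promotes A/B to float and
// accumulates in float. Used when no SM80 tensor-core atom matches the dtype.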
template <typename A, typename B>
struct F32FMA {
  using C = float;
  using D = float;
  using DRegisters = D[1];
  using ARegisters = A[1];
  using BRegisters = B[1];
  using CRegisters = C[1];
  CUTE_HOST_DEVICE static void fma(D& d, const A& a, const B& b, const C& c) {
    d = float(a) * float(b) + c;
  }
};

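// Register the atom with CuTe: one thread owns the whole 1x1x1 tile with
// scalar A/B/C value layouts.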
template <typename A, typename B>
struct MMA_Traits<F32FMA<A,B>> {
  using ValTypeD = float;
  using ValTypeA = A;
  using ValTypeB = B;
  using ValTypeC = float;
  using Shape_MNK = Shape<_1,_1,_1>;
  using ThrID = Layout<_1>;
  using ALayout = Layout<Shape<_1,_1>>;
  using BLayout = Layout<Shape<_1,_1>>;
  using CLayout = Layout<Shape<_1,_1>>;
};

} // namespace cute

// We can't put the kernel code in mlx::core due to a name conflict with "Shape".
namespace cute_gemm {

using namespace cute;

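// Quantized GEMM kernel. Each CTA computes one (BLK_M,BLK_N) tile of C for one
// batch element: per k-tile, A and the packed quantized B are loaded from gmem
// into registers, B is dequantized with its per-group scales/zero-points, both
// are staged in shared memory, and the tiled MMA accumulates into registers.
// The loads for tile k+1 are issued before computing tile k so memory and math
// overlap.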
template <typename ProblemShape, typename CtaTiler,
          typename Element, typename Quant,
          typename AStride, typename ASmemLayout, typename TiledCopyA,
          typename BStride, typename BSmemLayout, typename TiledCopyB,
          typename SLayout, typename CStride, typename TiledMma>
__global__ void qmm_impl(
    ProblemShape shape_MNKL, CtaTiler cta_tiler,
    const Element* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
    const Quant* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
    const Element* S, const Element* Z, SLayout S_layout,
    Element* C, CStride dC, TiledMma mma) {
  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));
  CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA));
  CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB));
  CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC));

  int thread_idx = int(threadIdx.x);
  auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);

  // Represent the full tensors.
  Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
  Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
  Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)
  Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)
  Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)

  // Get batch slice.
  Tensor mA = mA_mkl(_,_,l_coord); // (M,K)
  Tensor mB = mB_nkl(_,_,l_coord); // (N,K)
  Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
  Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
  Tensor mC = mC_mnl(_,_,l_coord); // (M,N)

  // Get the appropriate blocks for this thread block.
  auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)

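  // Only M needs bounds predication: the host requires N % 128 == 0 and
  // K % 64 == 0, which match BLK_N and BLK_K.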
  auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord

  // Shared memory buffers.
  __shared__ Element smemA[cosize_v<ASmemLayout>];
  __shared__ Element smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)

  // Partition the copying of A and B tiles across the threads.
  ThrCopy thr_copy_a = copy_a.get_slice(thread_idx);
  Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K)
  Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K)

  ThrCopy thr_copy_b = copy_b.get_slice(thread_idx);
  Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K)
  Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K)
  Tensor tBrBq = make_fragment_like<Quant>(tBsB); // (BCPY,BCPY_N,BCPY_K)
  Tensor tBgS = thr_copy_b.partition_S(gS); // (BCPY,BCPY_N,BCPY_K,k)
  Tensor tBgZ = thr_copy_b.partition_S(gZ); // (BCPY,BCPY_N,BCPY_K,k)

  // MMA.
  ThrMMA thr_mma = mma.get_slice(thread_idx);
  Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

  // Accumulators.
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);
  clear(tCrC);

  // Predicates for m bounds.
  Tensor tApA = make_tensor<bool>(make_shape(size<1>(tAsA), size<2>(tAsA)),
                                  Stride<_1,_0>{}); // (ACPY_M,ACPY_K)
  Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K)
  Tensor tAcA = thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K)
  CUTE_UNROLL
  for (int m = 0; m < size<0>(tApA); ++m) {
    tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord;
  }

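  // Prefetch the first k-tile into registers before the main loop so every
  // iteration can issue the next tile's loads before running the MMA.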
  // Copy gmem to rmem for k_tile=0.
  copy_if(copy_a, tApA, tAgA(_,_,_,0), tArA);
  copy(copy_b, tBgB(_,_,_,0), tBrBq);

  auto K_TILE_MAX = size<3>(tAgA);

  // Main loop.
  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) {
    __syncthreads();

    // Dequantize B and then copy A/B to smem.
    Tensor scale = tBgS(_,_,_,k_tile);
    Tensor zero_point = tBgZ(_,_,_,k_tile);
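    // Affine dequantization: w = q * scale + zero_point. The scale/zero-point
    // layout has a stride-0 group mode, so every value within a group of
    // group_size weights reads the same scale and zero-point.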
    for (int i = 0; i < size(tBrB); ++i) {
      tBrB(i) = tBrBq(i) * scale(i) + zero_point(i);
    }
    copy(tArA, tAsA);
    copy(tBrB, tBsB);
    __syncthreads();

    // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors.
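    // The index is clamped so the last iteration re-reads the final tile
    // instead of reading out of bounds; that extra data is never consumed.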
    int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
    copy_if(copy_a, tApA, tAgA(_,_,_,k_tile_next), tArA);
    copy(copy_b, tBgB(_,_,_,k_tile_next), tBrBq);

    // Compute gemm on mma-partitioned smem.
    gemm(mma, tCsA, tCsB, tCrC);
  }

  copy(tCrC, tCgC);
}

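// Choose a shared-memory swizzle based on the element width and quantization
// group size; it is composed with an (8,GroupSize) row-major atom in qmm()
// below to reduce smem bank conflicts.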
template <typename Element, typename GroupSize, typename F>
inline auto dispatch_swizzle(F&& f) {
  if constexpr (sizeof(Element) == 4) {
    if constexpr (GroupSize::value <= 32) {
      f(Swizzle<3,2,3>{});
    } else {
      f(Swizzle<3,3,3>{});
    }
  } else {
    if constexpr (GroupSize::value <= 32) {
      f(Swizzle<2,3,3>{});
    } else {
      f(Swizzle<3,3,3>{});
    }
  }
}

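// Prefer SM80 tensor-core MMA atoms (tf32 for float, fp16 for half); otherwise
// fall back to the scalar F32FMA atom with a 16x8 thread tiling.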
template <typename Element, typename F>
inline auto dispatch_mma(bool is_sm80, F&& f) {
  if (is_sm80) {
    if constexpr (std::is_same_v<Element, float>) {
      f(make_tiled_mma(SM80_16x8x8_F32TF32TF32F32_TN{},
                       Layout<Shape<_1,_4,_1>>{},
                       Tile<_16,_32,_8>{}));
      return;
    } else if constexpr (std::is_same_v<Element, cute::half_t>) {
      f(make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                       Layout<Shape<_1,_4,_1>>{},
                       Tile<_16,_32,_16>{}));
      return;
    }
  }
  f(make_tiled_mma(F32FMA<Element, Element>{},
                   Layout<Shape<_16,_8,_1>>{}));
}

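// Host-side launcher: builds the problem shape, TN strides, the broadcasted
// scale/zero-point layout and the CTA tiler, picks the copy/swizzle/MMA
// configurations, and hands the kernel plus its argument list to launch_kernel.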
template <typename GroupSize, typename Element, typename Quant, typename F>
void qmm(
    int m, int n, int k, int l,
    GroupSize group_size,
    const Element* A,
    const Quant* B,
    const Element* S,
    const Element* Z,
    Element* C,
    bool is_sm80,
    F&& launch_kernel) {
  // Define shapes (dynamic).
  auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L)

  // Define TN strides (mixed).
  auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL)
  auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL)
  auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL)

  // Define layout of scales (mixed).
  auto S_layout = make_layout(
      make_shape(n, make_shape(group_size, k / group_size), l),
      make_stride(k / group_size, make_stride(Int<0>{}, Int<1>{}), n * k / group_size));

  // Define CTA tile sizes (static).
  auto bM = Int<16>{};
  auto bN = Int<128>{};
  auto bK = Int<max(64,group_size)>{};
  auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K)

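  // Tiled copies: 128 threads arranged 16x8. Each thread handles 8 elements of
  // A (via 128-bit vector copies) and 32 bits of packed quantized B.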
  TiledCopy copy_a = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, Element>{},
                                     Layout<Shape<_16,_8>,Stride<_8,_1>>{},
                                     Layout<Shape< _1,_8>>{});
  TiledCopy copy_b = make_tiled_copy(Copy_Atom<UniversalCopy<uint32_t>, Quant>{},
                                     Layout<Shape<_16,_8>,Stride<_8,_1>>{},
                                     Layout<Shape<_1,Int<32/sizeof_bits<Quant>::value>>>{});

  // Define the smem layouts (static).
  dispatch_swizzle<Element, GroupSize>([&](auto swizzle) {
    auto swizzle_atom = composition(swizzle,
                                    Layout<Shape<_8,GroupSize>,
                                           Stride<GroupSize,_1>>{});
    auto sA_layout = tile_to_shape(swizzle_atom, make_shape(bM, bK));
    auto sB_layout = tile_to_shape(swizzle_atom, make_shape(bN, bK));

    // Create tiled MMA.
    dispatch_mma<Element>(is_sm80, [&](auto mma) {
      // Launch kernel.
      auto* kernel = &qmm_impl<
          decltype(prob_shape), decltype(cta_tiler),
          Element, Quant,
          decltype(dA), decltype(sA_layout), decltype(copy_a),
          decltype(dB), decltype(sB_layout), decltype(copy_b),
          decltype(S_layout), decltype(dC), decltype(mma)>;
      dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l);
      dim3 block_dims(size(mma));
      void* args[] = {
          &prob_shape, &cta_tiler,
          &A, &dA, &sA_layout, &copy_a,
          &B, &dB, &sB_layout, &copy_b,
          &S, &Z, &S_layout,
          &C, &dC, &mma};
      launch_kernel(reinterpret_cast<void*>(kernel), num_blocks, block_dims, 0, args);
    });
  });
}

} // namespace cute_gemm

// clang-format on

namespace mlx::core {

template <typename F>
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
  if (dtype == float32) {
    f.template operator()<float>();
  } else if (dtype == float16) {
    f.template operator()<cute::half_t>();
  } else {
    throw std::invalid_argument(
        fmt::format(
            "[{0}] Unsupported dtype: {1}.", tag, dtype_to_string(dtype)));
  }
}

template <typename F>
inline void dispatch_quant_types(int bits, const char* tag, F&& f) {
  if (bits == 8) {
    f.template operator()<uint8_t>();
  } else {
    throw std::invalid_argument(
        fmt::format("[{0}] {1}-bit quantization is not supported.", tag, bits));
  }
}

template <typename F>
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
  if (group_size == 16) {
    f(cute::Int<16>{});
  } else if (group_size == 32) {
    f(cute::Int<32>{});
  } else if (group_size == 64) {
    f(cute::Int<64>{});
  } else {
    throw std::invalid_argument(
        fmt::format("[{0}] Group size {1} is not supported.", tag, group_size));
  }
}

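// Validate that the problem shape matches the kernel's tile constraints, then
// dispatch on dtype, quantization bits and group size and record the kernel
// launch on the encoder.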
void cute_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder) {
  const char* tag = "quantized_matmul";
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  int l = out.size() / (m * n);
  if (n % 128 != 0) {
    throw std::runtime_error(
        fmt::format("[{0}] N must be a multiple of 128.", tag));
  }
  if (k % 64 != 0) {
    throw std::runtime_error(
        fmt::format("[{0}] K must be a multiple of 64.", tag));
  }
  if (l != 1 && m % 16 != 0) {
    throw std::runtime_error(
        fmt::format("[{0}] M must be a multiple of 16 for batched GEMM.", tag));
  }
  dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
    dispatch_quant_types(bits, tag, [&]<typename Quant>() {
      dispatch_groups(group_size, tag, [&](auto group_size) {
        encoder.set_input_array(x);
        encoder.set_input_array(w);
        encoder.set_input_array(scales);
        encoder.set_input_array(biases);
        encoder.set_output_array(out);
        cute_gemm::qmm(
            m,
            n,
            k,
            l,
            group_size,
            gpu_ptr<Element>(x),
            gpu_ptr<Quant>(w),
            gpu_ptr<Element>(scales),
            gpu_ptr<Element>(biases),
            gpu_ptr<Element>(out),
            encoder.device().compute_capability_major() >= 8,
            [&](auto* kernel,
                dim3 num_blocks,
                dim3 block_dims,
                uint32_t smem_bytes,
                void** args) {
              encoder.add_kernel_node(
                  kernel, num_blocks, block_dims, smem_bytes, args);
            });
      });
    });
  });
}

} // namespace mlx::core