Commit 2fa4bca

[CUDA] fp16/32 x int8 quantized matmul

1 parent: 3bbe87e
File tree: 4 files changed (+396, -4 lines)

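Per the dequantize step in qmm.cu below, the new kernel multiplies fp16/fp32 activations by 8-bit affine-quantized weights, expanding each weight on the fly as q * scale + bias, with one scale/bias pair shared by a group of group_size consecutive weights along K. A minimal CPU reference of that computation for orientation only; the qmm_reference name and plain-float types are illustrative and not part of the commit:

// Illustrative reference sketch, not code from this commit.
#include <cstdint>
#include <vector>

void qmm_reference(
    int M, int N, int K, int group_size,
    const std::vector<float>& A,      // (M, K) activations, row-major
    const std::vector<uint8_t>& B,    // (N, K) quantized weights, row-major
    const std::vector<float>& scales, // (N, K / group_size)
    const std::vector<float>& biases, // (N, K / group_size)
    std::vector<float>& C) {          // (M, N) output
  int groups = K / group_size;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        int g = k / group_size;
        // Affine dequantization: one scale/bias pair per group along K.
        float w = float(B[n * K + k]) * scales[n * groups + g] + biases[n * groups + g];
        acc += A[m * K + k] * w;
      }
      C[m * N + n] = acc;
    }
  }
}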

mlx/backend/cuda/CMakeLists.txt (1 addition, 0 deletions)

@@ -56,6 +56,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
+    ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp

mlx/backend/cuda/quantized/qmm.cu (new file, 368 additions)

// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm.h"
#include "mlx/dtype_utils.h"

#include <cute/layout.hpp>
#include <cute/tensor.hpp>

// clang-format off

namespace cute {

template <typename A, typename B>
struct F32FMA {
  using C = float;
  using D = float;
  using DRegisters = D[1];
  using ARegisters = A[1];
  using BRegisters = B[1];
  using CRegisters = C[1];
  CUTE_HOST_DEVICE static void fma(D& d, const A& a, const B& b, const C& c) {
    d = float(a) * float(b) + c;
  }
};

template <typename A, typename B>
struct MMA_Traits<F32FMA<A,B>> {
  using ValTypeD = float;
  using ValTypeA = A;
  using ValTypeB = B;
  using ValTypeC = float;
  using Shape_MNK = Shape<_1,_1,_1>;
  using ThrID = Layout<_1>;
  using ALayout = Layout<Shape<_1,_1>>;
  using BLayout = Layout<Shape<_1,_1>>;
  using CLayout = Layout<Shape<_1,_1>>;
};

} // namespace cute

// We can't put kernel code in mlx::core due to name conflicts of "Shape".
namespace cute_gemm {

using namespace cute;

template <typename ProblemShape, typename CtaTiler,
          typename Element, typename Quant,
          typename AStride, typename ASmemLayout, typename TiledCopyA,
          typename BStride, typename BSmemLayout, typename TiledCopyB,
          typename SLayout, typename CStride, typename TiledMma>
__global__ void qmm_impl(
    ProblemShape shape_MNKL, CtaTiler cta_tiler,
    const Element* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
    const Quant* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
    const Element* S, const Element* Z, SLayout S_layout,
    Element* C, CStride dC, TiledMma mma) {
  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));
  CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA));
  CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB));
  CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC));

  int thread_idx = int(threadIdx.x);
  auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);

  // Represent the full tensors.
  Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
  Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
  Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)
  Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)
  Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)

  // Get batch slice.
  Tensor mA = mA_mkl(_,_,l_coord); // (M,K)
  Tensor mB = mB_nkl(_,_,l_coord); // (N,K)
  Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
  Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
  Tensor mC = mC_mnl(_,_,l_coord); // (M,N)

  // Get the appropriate blocks for this thread block.
  auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)

  auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord

  // Shared memory buffers.
  __shared__ Element smemA[cosize_v<ASmemLayout>];
  __shared__ Element smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)

  // Partition the copying of A and B tiles across the threads.
  ThrCopy thr_copy_a = copy_a.get_slice(thread_idx);
  Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K)
  Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K)

  ThrCopy thr_copy_b = copy_b.get_slice(thread_idx);
  Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K)
  Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K)
  Tensor tBrBq = make_fragment_like<Quant>(tBsB); // (BCPY,BCPY_N,BCPY_K)
  Tensor tBgS = thr_copy_b.partition_S(gS); // (BCPY,BCPY_N,BCPY_K,k)
  Tensor tBgZ = thr_copy_b.partition_S(gZ); // (BCPY,BCPY_N,BCPY_K,k)

  // MMA.
  ThrMMA thr_mma = mma.get_slice(thread_idx);
  Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

  // Accumulators.
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);
  clear(tCrC);

  // Predicates for m bounds.
  Tensor tApA = make_tensor<bool>(make_shape(size<1>(tAsA), size<2>(tAsA)),
                                  Stride<_1,_0>{}); // (ACPY_M,ACPY_K)
  Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K)
  Tensor tAcA = thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K)
  CUTE_UNROLL
  for (int m = 0; m < size<0>(tApA); ++m) {
    tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord;
  }

  // Copy gmem to rmem for k_tile=0.
  copy_if(copy_a, tApA, tAgA(_,_,_,0), tArA);
  copy(copy_b, tBgB(_,_,_,0), tBrBq);

  auto K_TILE_MAX = size<3>(tAgA);

  // Main loop.
  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) {
    __syncthreads();

    // Dequantize B and then copy A/B to smem.
    Tensor scale = tBgS(_,_,_,k_tile);
    Tensor zero_point = tBgZ(_,_,_,k_tile);
    for (int i = 0; i < size(tBrB); ++i) {
      tBrB(i) = tBrBq(i) * scale(i) + zero_point(i);
    }
    copy(tArA, tAsA);
    copy(tBrB, tBsB);
    __syncthreads();

    // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors.
    int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
    copy_if(copy_a, tApA, tAgA(_,_,_,k_tile_next), tArA);
    copy(copy_b, tBgB(_,_,_,k_tile_next), tBrBq);

    // Compute gemm on mma-partitioned smem.
    gemm(mma, tCsA, tCsB, tCrC);
  }

  copy(tCrC, tCgC);
}

template <typename Element, typename GroupSize, typename F>
inline auto dispatch_swizzle(F&& f) {
  if constexpr (sizeof(Element) == 4) {
    if constexpr (GroupSize::value <= 32) {
      f(Swizzle<3,2,3>{});
    } else {
      f(Swizzle<3,3,3>{});
    }
  } else {
    if constexpr (GroupSize::value <= 32) {
      f(Swizzle<2,3,3>{});
    } else {
      f(Swizzle<3,3,3>{});
    }
  }
}

template <typename Element, typename F>
inline auto dispatch_mma(bool is_sm80, F&& f) {
  if (is_sm80) {
    if constexpr (std::is_same_v<Element, float>) {
      f(make_tiled_mma(SM80_16x8x8_F32TF32TF32F32_TN{},
                       Layout<Shape<_1,_4,_1>>{},
                       Tile<_16,_32,_8>{}));
      return;
    } else if constexpr (std::is_same_v<Element, cute::half_t>) {
      f(make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                       Layout<Shape<_1,_4,_1>>{},
                       Tile<_16,_32,_16>{}));
      return;
    }
  }
  f(make_tiled_mma(F32FMA<Element, Element>{},
                   Layout<Shape<_16,_8,_1>>{}));
}

template <typename GroupSize, typename Element, typename Quant, typename F>
void qmm(
    int m, int n, int k, int l,
    GroupSize group_size,
    const Element* A,
    const Quant* B,
    const Element* S,
    const Element* Z,
    Element* C,
    bool is_sm80,
    F&& launch_kernel) {
  // Define shapes (dynamic).
  auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L)

  // Define TN strides (mixed).
  auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL)
  auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL)
  auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL)

  // Define layout of scales (mixed).
  auto S_layout = make_layout(
      make_shape(n, make_shape(group_size, k / group_size), l),
      make_stride(k / group_size, make_stride(Int<0>{}, Int<1>{}), n * k / group_size));

  // Define CTA tile sizes (static).
  auto bM = Int<16>{};
  auto bN = Int<128>{};
  auto bK = Int<max(64,group_size)>{};
  auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K)

  TiledCopy copy_a = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, Element>{},
                                     Layout<Shape<_16,_8>,Stride<_8,_1>>{},
                                     Layout<Shape< _1,_8>>{});
  TiledCopy copy_b = make_tiled_copy(Copy_Atom<UniversalCopy<uint32_t>, Quant>{},
                                     Layout<Shape<_16,_8>,Stride<_8,_1>>{},
                                     Layout<Shape<_1,Int<32/sizeof_bits<Quant>::value>>>{});

  // Define the smem layouts (static).
  dispatch_swizzle<Element, GroupSize>([&](auto swizzle) {
    auto swizzle_atom = composition(swizzle,
                                    Layout<Shape<_8,GroupSize>,
                                           Stride<GroupSize,_1>>{});
    auto sA_layout = tile_to_shape(swizzle_atom, make_shape(bM, bK));
    auto sB_layout = tile_to_shape(swizzle_atom, make_shape(bN, bK));

    // Create tiled MMA.
    dispatch_mma<Element>(is_sm80, [&](auto mma) {
      // Launch kernel.
      auto* kernel = &qmm_impl<
          decltype(prob_shape), decltype(cta_tiler),
          Element, Quant,
          decltype(dA), decltype(sA_layout), decltype(copy_a),
          decltype(dB), decltype(sB_layout), decltype(copy_b),
          decltype(S_layout), decltype(dC), decltype(mma)>;
      dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l);
      dim3 block_dims(size(mma));
      void* args[] = {
          &prob_shape, &cta_tiler,
          &A, &dA, &sA_layout, &copy_a,
          &B, &dB, &sB_layout, &copy_b,
          &S, &Z, &S_layout,
          &C, &dC, &mma};
      launch_kernel(reinterpret_cast<void*>(kernel), num_blocks, block_dims, 0, args);
    });
  });
}

} // namespace cute_gemm

// clang-format on

namespace mlx::core {

template <typename F>
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
  if (dtype == float32) {
    f.template operator()<float>();
  } else if (dtype == float16) {
    f.template operator()<cutlass::half_t>();
  } else {
    throw std::invalid_argument(
        fmt::format(
            "[{0}] Unsupported dtype: {1}.", tag, dtype_to_string(dtype)));
  }
}

template <typename F>
inline void dispatch_quant_types(int bits, const char* tag, F&& f) {
  if (bits == 8) {
    f.template operator()<uint8_t>();
  } else {
    throw std::invalid_argument(
        fmt::format("[{0}] {1}-bit quantization is not supported.", tag, bits));
  }
}

template <typename F>
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
  if (group_size == 16) {
    f(cute::Int<16>{});
  } else if (group_size == 32) {
    f(cute::Int<32>{});
  } else if (group_size == 64) {
    f(cute::Int<64>{});
  } else {
    throw std::invalid_argument(
        fmt::format("[{0}] Group size {1} is not supported.", tag, group_size));
  }
}

void cute_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder) {
  const char* tag = "quantized_matmul";
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  int l = out.size() / (m * n);
  if (n % 128 != 0) {
    throw std::runtime_error(
        fmt::format("[{0}] N must be a multiple of 128.", tag));
  }
  if (k % 64 != 0) {
    throw std::runtime_error(
        fmt::format("[{0}] K must be a multiple of 64.", tag));
  }
  if (l != 1 && m % 16 != 0) {
    throw std::runtime_error(
        fmt::format("[{0}] M must be a multiple of 16 for batched GEMM.", tag));
  }
  dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
    dispatch_quant_types(bits, tag, [&]<typename Quant>() {
      dispatch_groups(group_size, tag, [&](auto group_size) {
        encoder.set_input_array(x);
        encoder.set_input_array(w);
        encoder.set_input_array(scales);
        encoder.set_input_array(biases);
        encoder.set_output_array(out);
        cute_gemm::qmm(
            m,
            n,
            k,
            l,
            group_size,
            gpu_ptr<Element>(x),
            gpu_ptr<Quant>(w),
            gpu_ptr<Element>(scales),
            gpu_ptr<Element>(biases),
            gpu_ptr<Element>(out),
            encoder.device().compute_capability_major() >= 8,
            [&](auto* kernel,
                dim3 num_blocks,
                dim3 block_dims,
                uint32_t smem_bytes,
                void** args) {
              encoder.add_kernel_node(
                  kernel, num_blocks, block_dims, smem_bytes, args);
            });
      });
    });
  });
}

} // namespace mlx::core
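
One detail worth noting in qmm.cu above is S_layout: the inner (group_size, K/group_size) mode gives the group dimension a stride of 0, so every weight in a group reads the same scale/bias element and the dequantize loop can index scales exactly like weights. A small self-contained sketch of that broadcast trick, illustrative only; the demo function and values are not from the commit:

#include <cute/tensor.hpp>

void broadcast_scale_demo() {
  using namespace cute;
  // One row of scales for K = 8, group_size = 4 -> 2 groups.
  float scales[2] = {0.5f, 2.0f};
  auto layout = make_layout(
      make_shape(Int<4>{}, Int<2>{}),   // (group_size, K / group_size)
      make_stride(Int<0>{}, Int<1>{})); // stride 0 repeats the value within a group
  Tensor s = make_tensor(&scales[0], layout);
  // s(0,0) == s(1,0) == s(2,0) == s(3,0) == 0.5f
  // s(0,1) == s(1,1) == s(2,1) == s(3,1) == 2.0f
}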

mlx/backend/cuda/quantized/qmm.h (new file, 19 additions)

// Copyright © 2026 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"

namespace mlx::core {

void cute_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder);

} // namespace mlx::core
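
For context, a hypothetical call site for the entry point declared above; example_dispatch and the literal bits/group_size values are illustrative and not code from this commit. The caller is expected to supply row-major x and w, grouped scales/biases, the output array, and the stream's command encoder:

#include "mlx/backend/cuda/quantized/qmm.h"

// Hypothetical usage sketch, not part of the commit.
void example_dispatch(
    const mlx::core::array& x,
    const mlx::core::array& w,
    const mlx::core::array& scales,
    const mlx::core::array& biases,
    mlx::core::array& out,
    mlx::core::cu::CommandEncoder& encoder) {
  // 8-bit weights with groups of 64 along K, the configuration qmm.cu accepts.
  mlx::core::cute_qmm(x, w, scales, biases, out, /*bits=*/8, /*group_size=*/64, encoder);
}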
