Skip to content

Commit ce45c52

Browse files
authored
[CUDA] Use qmv kernel for fp quantizations (#3239)
1 parent 0879a6a commit ce45c52

3 files changed

Lines changed: 57 additions & 32 deletions

File tree

mlx/backend/cuda/quantized/qmm/qmm.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ bool supports_fp_qmv(
108108
int group_size,
109109
QuantizationMode mode,
110110
cu::Device& device) {
111+
// The fp_qmv kernel uses less registers and is faster for sm120. For sm80/90
112+
// the qmv kernel is faster. We didn't test sm89/100.
113+
if (device.compute_capability_major() <= 9) {
114+
return false;
115+
}
111116
bool non_batched = w.ndim() == 2;
112117
int k = x.shape(-1);
113118
int n = out.shape(-1);
@@ -149,9 +154,6 @@ bool supports_qmv(
149154
if (!transpose) {
150155
return false;
151156
}
152-
if (mode != QuantizationMode::Affine) {
153-
return false;
154-
}
155157
return true;
156158
}
157159

mlx/backend/cuda/quantized/qmm/qmv.cu

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ namespace cg = cooperative_groups;
117117
// Fused vectorized dequantize and multiply-add:
118118
// w_dq = w * scale + bias
119119
// out = fma(x, w_dq, out)
120-
template <int N, typename T, typename Q>
120+
template <int N, bool has_bias, typename T, typename Q, typename S>
121121
__device__ __forceinline__ void
122-
dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
122+
dequant_fma(const T* x, const Q* w, S scale, T bias, T* out) {
123123
// Read x/w into registers.
124124
auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));
125125
auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));
@@ -129,13 +129,17 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
129129
// Dequantize w.
130130
cutlass::NumericArrayConverter<T, Q, N> converter_tq;
131131
cutlass::Array<T, N> w_dq = converter_tq(w_vec);
132-
if constexpr (cuda::std::is_same_v<T, float>) {
132+
if constexpr (has_bias) {
133+
if constexpr (cuda::std::is_same_v<T, float>) {
133134
#pragma unroll
134-
for (int i = 0; i < N; ++i) {
135-
w_dq[i] = w_dq[i] * scale + bias;
135+
for (int i = 0; i < N; ++i) {
136+
w_dq[i] = w_dq[i] * T(scale) + bias;
137+
}
138+
} else {
139+
w_dq = w_dq * T(scale) + bias;
136140
}
137141
} else {
138-
w_dq = w_dq * scale + bias;
142+
w_dq = w_dq * T(scale);
139143
}
140144

141145
// Multiply and add.
@@ -145,11 +149,13 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
145149
// Specialization for doing float32 accumulations on narrow types.
146150
template <
147151
int N,
152+
bool has_bias,
148153
typename T,
149154
typename Q,
155+
typename S,
150156
typename = cuda::std::enable_if_t<!cuda::std::is_same_v<T, float>>>
151157
__device__ __forceinline__ void
152-
dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
158+
dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) {
153159
// Read x/w into registers.
154160
auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));
155161
auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));
@@ -159,7 +165,11 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
159165
// Dequantize w.
160166
cutlass::NumericArrayConverter<T, Q, N> converter_tq;
161167
cutlass::Array<T, N> w_dq = converter_tq(w_vec);
162-
w_dq = w_dq * scale + bias;
168+
if constexpr (has_bias) {
169+
w_dq = w_dq * T(scale) + bias;
170+
} else {
171+
w_dq = w_dq * T(scale);
172+
}
163173

164174
// Promote x/w to float.
165175
static_assert(!cuda::std::is_same_v<T, float>);
@@ -178,11 +188,12 @@ template <
178188
bool has_bias,
179189
bool has_residue_k,
180190
typename T,
181-
typename Q>
191+
typename Q,
192+
typename S>
182193
__global__ void qmv_kernel(
183194
const T* x,
184195
const Q* w,
185-
const T* scales,
196+
const S* scales,
186197
const T* biases,
187198
T* out,
188199
int n,
@@ -224,12 +235,13 @@ __global__ void qmv_kernel(
224235
cuda::std::conditional_t<(bits >= 8), float, T> sums[elems_per_thread] = {};
225236

226237
auto dequant_fma_tile = [&](int idx) {
227-
T scale = scales[idx / group_size];
238+
S scale = scales[idx / group_size];
228239
T bias{0};
229240
if constexpr (has_bias) {
230241
bias = biases[idx / group_size];
231242
}
232-
dequant_fma<elems_per_thread>(x + idx, w + w_step(idx), scale, bias, sums);
243+
dequant_fma<elems_per_thread, has_bias>(
244+
x + idx, w + w_step(idx), scale, bias, sums);
233245
};
234246

235247
// Loop over k dimension.
@@ -262,11 +274,17 @@ __global__ void qmv_kernel(
262274
}
263275
}
264276

265-
template <int group_size, bool has_bias, typename T, typename Q, typename F>
277+
template <
278+
int group_size,
279+
bool has_bias,
280+
typename T,
281+
typename Q,
282+
typename S,
283+
typename F>
266284
void qmv(
267285
const T* x,
268286
const Q* w,
269-
const T* scales,
287+
const S* scales,
270288
const T* biases,
271289
T* out,
272290
int m,
@@ -292,7 +310,8 @@ void qmv(
292310
has_bias,
293311
has_residue_k.value,
294312
T,
295-
Q>;
313+
Q,
314+
S>;
296315
launch_kernel(
297316
reinterpret_cast<void*>(kernel), num_blocks, block_dims, args);
298317
});
@@ -328,33 +347,33 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
328347
}
329348
}
330349

331-
template <typename F>
350+
template <typename T, typename F>
332351
inline void dispatch_quant_types(
333352
int bits,
334353
int group_size,
335354
QuantizationMode mode,
336355
const char* tag,
337356
F&& f) {
338357
if (mode == QuantizationMode::Mxfp4) {
339-
f.template operator()<cutlass::float_e2m1_t, 16>();
358+
f.template operator()<cutlass::float_e2m1_t, cutlass::float_ue8m0_t, 32>();
340359
} else if (mode == QuantizationMode::Mxfp8) {
341-
f.template operator()<cutlass::float_e4m3_t, 32>();
360+
f.template operator()<cutlass::float_e4m3_t, cutlass::float_ue8m0_t, 32>();
342361
} else if (mode == QuantizationMode::Nvfp4) {
343-
f.template operator()<cutlass::float_e2m1_t, 32>();
362+
f.template operator()<cutlass::float_e2m1_t, cutlass::float_e4m3_t, 16>();
344363
} else {
345364
dispatch_groups(group_size, tag, [&]<int group_size>() {
346365
if (bits == 2) {
347-
f.template operator()<cutlass::uint2b_t, group_size>();
366+
f.template operator()<cutlass::uint2b_t, T, group_size>();
348367
} else if (bits == 3) {
349-
f.template operator()<cutlass::uint3b_t, group_size>();
368+
f.template operator()<cutlass::uint3b_t, T, group_size>();
350369
} else if (bits == 4) {
351-
f.template operator()<cutlass::uint4b_t, group_size>();
370+
f.template operator()<cutlass::uint4b_t, T, group_size>();
352371
} else if (bits == 5) {
353-
f.template operator()<cutlass::uint5b_t, group_size>();
372+
f.template operator()<cutlass::uint5b_t, T, group_size>();
354373
} else if (bits == 6) {
355-
f.template operator()<cutlass::uint6b_t, group_size>();
374+
f.template operator()<cutlass::uint6b_t, T, group_size>();
356375
} else if (bits == 8) {
357-
f.template operator()<uint8_t, group_size>();
376+
f.template operator()<uint8_t, T, group_size>();
358377
} else {
359378
throw std::invalid_argument(
360379
fmt::format("{} {}-bit quantization is not supported.", tag, bits));
@@ -381,8 +400,12 @@ void qmv(
381400
bool broadcast_w = w.ndim() == 2;
382401

383402
dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
384-
dispatch_quant_types(
385-
bits, group_size, mode, tag, [&]<typename Q, int group_size>() {
403+
dispatch_quant_types<T>(
404+
bits,
405+
group_size,
406+
mode,
407+
tag,
408+
[&]<typename Q, typename S, int group_size>() {
386409
encoder.set_input_array(x);
387410
encoder.set_input_array(w);
388411
encoder.set_input_array(scales);
@@ -394,7 +417,7 @@ void qmv(
394417
cu::qmv<group_size, has_bias>(
395418
gpu_ptr<T>(x),
396419
gpu_ptr<Q>(w),
397-
gpu_ptr<T>(scales),
420+
gpu_ptr<S>(scales),
398421
biases ? gpu_ptr<T>(*biases) : nullptr,
399422
gpu_ptr<T>(out),
400423
m,

mlx/backend/cuda/quantized/qqmm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
108108

109109
const array& w = inputs[1];
110110
const array& scales = inputs[2];
111-
fp_qmv(xhat, w, scales, out, bits_, group_size_, encoder, s);
111+
qmv(xhat, w, scales, std::nullopt, out, bits_, group_size_, mode_, encoder);
112112
return;
113113
}
114114

0 commit comments

Comments
 (0)