@@ -117,9 +117,9 @@ namespace cg = cooperative_groups;
117117// Fused vectorized dequantize and multiply-add:
118118// w_dq = w * scale + bias (the bias term is only applied when has_bias is true)
119119// out = fma(x, w_dq, out)
120- template <int N, typename T, typename Q>
120+ template <int N, bool has_bias, typename T, typename Q, typename S >
121121__device__ __forceinline__ void
122- dequant_fma (const T* x, const Q* w, T scale, T bias, T* out) {
122+ dequant_fma (const T* x, const Q* w, S scale, T bias, T* out) {
123123 // Read x/w into registers.
124124 auto x_vec = *(reinterpret_cast <const cutlass::Array<T, N>*>(x));
125125 auto w_vec = *(reinterpret_cast <const cutlass::Array<Q, N>*>(w));
@@ -129,13 +129,17 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
129129 // Dequantize w.
130130 cutlass::NumericArrayConverter<T, Q, N> converter_tq;
131131 cutlass::Array<T, N> w_dq = converter_tq (w_vec);
132- if constexpr (cuda::std::is_same_v<T, float >) {
132+ if constexpr (has_bias) {
133+ if constexpr (cuda::std::is_same_v<T, float >) {
133134#pragma unroll
134- for (int i = 0 ; i < N; ++i) {
135- w_dq[i] = w_dq[i] * scale + bias;
135+ for (int i = 0 ; i < N; ++i) {
136+ w_dq[i] = w_dq[i] * T (scale) + bias;
137+ }
138+ } else {
139+ w_dq = w_dq * T (scale) + bias;
136140 }
137141 } else {
138- w_dq = w_dq * scale + bias ;
142+ w_dq = w_dq * T ( scale) ;
139143 }
140144
141145 // Multiply and add.
@@ -145,11 +149,13 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
145149// Overload (not a specialization) that accumulates into a float32 buffer when T is a non-float type.
146150template <
147151 int N,
152+ bool has_bias,
148153 typename T,
149154 typename Q,
155+ typename S,
150156 typename = cuda::std::enable_if_t <!cuda::std::is_same_v<T, float >>>
151157__device__ __forceinline__ void
152- dequant_fma (const T* x, const Q* w, T scale, T bias, float * out) {
158+ dequant_fma (const T* x, const Q* w, S scale, T bias, float * out) {
153159 // Read x/w into registers.
154160 auto x_vec = *(reinterpret_cast <const cutlass::Array<T, N>*>(x));
155161 auto w_vec = *(reinterpret_cast <const cutlass::Array<Q, N>*>(w));
@@ -159,7 +165,11 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
159165 // Dequantize w.
160166 cutlass::NumericArrayConverter<T, Q, N> converter_tq;
161167 cutlass::Array<T, N> w_dq = converter_tq (w_vec);
162- w_dq = w_dq * scale + bias;
168+ if constexpr (has_bias) {
169+ w_dq = w_dq * T (scale) + bias;
170+ } else {
171+ w_dq = w_dq * T (scale);
172+ }
163173
164174 // Promote x/w to float.
165175 static_assert (!cuda::std::is_same_v<T, float >);
@@ -178,11 +188,12 @@ template <
178188 bool has_bias,
179189 bool has_residue_k,
180190 typename T,
181- typename Q>
191+ typename Q,
192+ typename S>
182193__global__ void qmv_kernel (
183194 const T* x,
184195 const Q* w,
185- const T * scales,
196+ const S * scales,
186197 const T* biases,
187198 T* out,
188199 int n,
@@ -224,12 +235,13 @@ __global__ void qmv_kernel(
224235 cuda::std::conditional_t <(bits >= 8 ), float , T> sums[elems_per_thread] = {};
225236
226237 auto dequant_fma_tile = [&](int idx) {
227- T scale = scales[idx / group_size];
238+ S scale = scales[idx / group_size];
228239 T bias{0 };
229240 if constexpr (has_bias) {
230241 bias = biases[idx / group_size];
231242 }
232- dequant_fma<elems_per_thread>(x + idx, w + w_step (idx), scale, bias, sums);
243+ dequant_fma<elems_per_thread, has_bias>(
244+ x + idx, w + w_step (idx), scale, bias, sums);
233245 };
234246
235247 // Loop over k dimension.
@@ -262,11 +274,17 @@ __global__ void qmv_kernel(
262274 }
263275}
264276
265- template <int group_size, bool has_bias, typename T, typename Q, typename F>
277+ template <
278+ int group_size,
279+ bool has_bias,
280+ typename T,
281+ typename Q,
282+ typename S,
283+ typename F>
266284void qmv (
267285 const T* x,
268286 const Q* w,
269- const T * scales,
287+ const S * scales,
270288 const T* biases,
271289 T* out,
272290 int m,
@@ -292,7 +310,8 @@ void qmv(
292310 has_bias,
293311 has_residue_k.value ,
294312 T,
295- Q>;
313+ Q,
314+ S>;
296315 launch_kernel (
297316 reinterpret_cast <void *>(kernel), num_blocks, block_dims, args);
298317 });
@@ -328,33 +347,33 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
328347 }
329348}
330349
331- template <typename F>
350+ template <typename T, typename F>
332351inline void dispatch_quant_types (
333352 int bits,
334353 int group_size,
335354 QuantizationMode mode,
336355 const char * tag,
337356 F&& f) {
338357 if (mode == QuantizationMode::Mxfp4) {
339- f.template operator ()<cutlass::float_e2m1_t , 16 >();
358+ f.template operator ()<cutlass::float_e2m1_t , cutlass:: float_ue8m0_t , 32 >();
340359 } else if (mode == QuantizationMode::Mxfp8) {
341- f.template operator ()<cutlass::float_e4m3_t , 32 >();
360+ f.template operator ()<cutlass::float_e4m3_t , cutlass:: float_ue8m0_t , 32 >();
342361 } else if (mode == QuantizationMode::Nvfp4) {
343- f.template operator ()<cutlass::float_e2m1_t , 32 >();
362+ f.template operator ()<cutlass::float_e2m1_t , cutlass:: float_e4m3_t , 16 >();
344363 } else {
345364 dispatch_groups (group_size, tag, [&]<int group_size>() {
346365 if (bits == 2 ) {
347- f.template operator ()<cutlass::uint2b_t , group_size>();
366+ f.template operator ()<cutlass::uint2b_t , T, group_size>();
348367 } else if (bits == 3 ) {
349- f.template operator ()<cutlass::uint3b_t , group_size>();
368+ f.template operator ()<cutlass::uint3b_t , T, group_size>();
350369 } else if (bits == 4 ) {
351- f.template operator ()<cutlass::uint4b_t , group_size>();
370+ f.template operator ()<cutlass::uint4b_t , T, group_size>();
352371 } else if (bits == 5 ) {
353- f.template operator ()<cutlass::uint5b_t , group_size>();
372+ f.template operator ()<cutlass::uint5b_t , T, group_size>();
354373 } else if (bits == 6 ) {
355- f.template operator ()<cutlass::uint6b_t , group_size>();
374+ f.template operator ()<cutlass::uint6b_t , T, group_size>();
356375 } else if (bits == 8 ) {
357- f.template operator ()<uint8_t , group_size>();
376+ f.template operator ()<uint8_t , T, group_size>();
358377 } else {
359378 throw std::invalid_argument (
360379 fmt::format (" {} {}-bit quantization is not supported." , tag, bits));
@@ -381,8 +400,12 @@ void qmv(
381400 bool broadcast_w = w.ndim () == 2 ;
382401
383402 dispatch_element_types (out.dtype (), tag, [&]<typename T>() {
384- dispatch_quant_types (
385- bits, group_size, mode, tag, [&]<typename Q, int group_size>() {
403+ dispatch_quant_types<T>(
404+ bits,
405+ group_size,
406+ mode,
407+ tag,
408+ [&]<typename Q, typename S, int group_size>() {
386409 encoder.set_input_array (x);
387410 encoder.set_input_array (w);
388411 encoder.set_input_array (scales);
@@ -394,7 +417,7 @@ void qmv(
394417 cu::qmv<group_size, has_bias>(
395418 gpu_ptr<T>(x),
396419 gpu_ptr<Q>(w),
397- gpu_ptr<T >(scales),
420+ gpu_ptr<S >(scales),
398421 biases ? gpu_ptr<T>(*biases) : nullptr ,
399422 gpu_ptr<T>(out),
400423 m,
0 commit comments