@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
 }
 
 void _qmm_dispatch(
-    array out,
+    array& out,
     const array& x,
     const array& w,
     const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
   }
 }
 
+void _bs_qmm_dispatch(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    int group_size,
+    int bits,
+    bool transposed_w) {
+  int K = x.shape(-1);
+  int M = x.shape(-2);
+  int N = out.shape(-1);
+
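+  // Number of elements in one w matrix and in one scales/biases matrix,
+  // used to offset to the matrices selected by rhs_indices.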
+  int w_els = w.shape(-1) * w.shape(-2);
+  int g_els = scales.shape(-1) * scales.shape(-2);
+
+  const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
+  const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
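+
+  // One M x N output matrix per index pair: multiply the x slice picked
+  // by lhs_indices[i] with the quantized w (and its scales/biases)
+  // picked by rhs_indices[i].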
+  for (int i = 0; i < lhs_indices.size(); i++) {
+    int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
+    int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
+
+    switch (x.dtype()) {
+      case float32:
+        _qmm_dispatch_typed<float>(
+            out.data<float>() + i * M * N,
+            x.data<float>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      case float16:
+        _qmm_dispatch_typed<float16_t>(
+            out.data<float16_t>() + i * M * N,
+            x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      case bfloat16:
+        _qmm_dispatch_typed<bfloat16_t>(
+            out.data<bfloat16_t>() + i * M * N,
+            x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      default:
+        throw std::invalid_argument(
+            "[quantized_matmul] only floating types are supported");
+    }
+  }
+}
+
 } // namespace
 
 void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }
 
+void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 6);
+
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
+  auto& lhs_indices = inputs[4];
+  auto& rhs_indices = inputs[5];
+
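+  // The dispatch kernels assume the last two dims of each input are
+  // row-contiguous; otherwise fall back to a dense copy.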
+  auto ensure_row_contiguous_last_dims = [](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
+
+  auto x = ensure_row_contiguous_last_dims(x_pre);
+  auto w = ensure_row_contiguous_last_dims(w_pre);
+  auto scales = ensure_row_contiguous_last_dims(scales_pre);
+  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  _bs_qmm_dispatch(
+      out,
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      group_size_,
+      bits_,
+      transpose_);
+}
+
 } // namespace mlx::core
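
For readers following the pointer arithmetic: `elem_to_loc` converts a flat row-major element index into a memory offset using the array's strides, so offsets such as `x_idx * M * K` stay correct when the batch dimensions are broadcast or otherwise strided. A minimal standalone sketch of that mapping (an illustrative reimplementation, not MLX's actual helper):

```cpp
#include <cstddef>
#include <vector>

// Map a flat row-major element index to a memory offset by peeling off
// one dimension at a time, innermost first. A broadcast dimension has
// stride 0, so it contributes nothing to the offset.
size_t elem_to_loc_sketch(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}
```

This is also why `ensure_row_contiguous_last_dims` only densifies the trailing two dimensions: the matmul kernels need contiguous matrices, while any remaining batch strides are absorbed by `elem_to_loc`.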