
Commit e78a651

Block sparse qmm (#1124)
1 parent 1873ffd commit e78a651

15 files changed: +1724, -164 lines changed

15 files changed

+1724
-164
lines changed

mlx/backend/accelerate/primitives.cpp

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
 DEFAULT(BlockSparseMM)
+DEFAULT(BlockSparseQMM)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
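
For context (not part of this diff): DEFAULT registers a pass-through CPU implementation, so the new BlockSparseQMM primitive falls back to the common reference eval rather than getting an Accelerate-specific kernel. A minimal sketch of what the macro presumably expands to, assuming it matches its usual definition at the top of this file:

// Sketch only: assumed shape of the DEFAULT macro (not shown in this diff).
// It forwards eval_cpu to the primitive's default reference implementation.
#define DEFAULT(primitive)                                                  \
  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
    primitive::eval(inputs, out);                                           \
  }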

mlx/backend/common/default_primitives.cpp

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ DEFAULT(AsStrided)
 DEFAULT(Broadcast)
 DEFAULT(BlockMaskedMM)
 DEFAULT(BlockSparseMM)
+DEFAULT(BlockSparseQMM)
 DEFAULT_MULTI(DivMod)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)

mlx/backend/common/quantized.cpp

Lines changed: 117 additions & 1 deletion
@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
 }
 
 void _qmm_dispatch(
-    array out,
+    array& out,
     const array& x,
     const array& w,
     const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
   }
 }
 
+void _bs_qmm_dispatch(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    int bits,
+    int group_size,
+    bool transposed_w) {
+  int K = x.shape(-1);
+  int M = x.shape(-2);
+  int N = out.shape(-1);
+
+  int w_els = w.shape(-1) * w.shape(-2);
+  int g_els = scales.shape(-1) * scales.shape(-2);
+
+  const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
+  const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
+
+  for (int i = 0; i < lhs_indices.size(); i++) {
+    int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
+    int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
+
+    switch (x.dtype()) {
+      case float32:
+        _qmm_dispatch_typed<float>(
+            out.data<float>() + i * M * N,
+            x.data<float>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      case float16:
+        _qmm_dispatch_typed<float16_t>(
+            out.data<float16_t>() + i * M * N,
+            x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      case bfloat16:
+        _qmm_dispatch_typed<bfloat16_t>(
+            out.data<bfloat16_t>() + i * M * N,
+            x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      default:
+        throw std::invalid_argument(
+            "[quantized_matmul] only floating types are supported");
+    }
+  }
+}
+
 } // namespace
 
 void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
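
Semantically, each iteration of the loop in _bs_qmm_dispatch above computes one gathered quantized matmul: the i-th (M, N) slab of out is the (M, K) slice of x selected by lhs_indices[i], multiplied by the weight matrix selected by rhs_indices[i] after dequantization with the matching scales and biases slices. The typed kernel applies MLX's affine scheme, w_hat = scale * q + bias, per group of group_size elements. A standalone sketch of that unpacking step (the helper name and the exact bit layout here are my assumptions, not code from this diff):

#include <cstdint>
#include <vector>

// Sketch: dequantize n consecutive elements that share one (scale, bias)
// pair, assuming bits-wide values packed little-end-first into uint32 words.
std::vector<float> dequantize_segment(
    const uint32_t* packed,
    float scale,
    float bias,
    int n,
    int bits) {
  std::vector<float> out(n);
  uint32_t mask = (1u << bits) - 1;
  int per_word = 32 / bits;
  for (int j = 0; j < n; ++j) {
    uint32_t q = (packed[j / per_word] >> ((j % per_word) * bits)) & mask;
    out[j] = scale * static_cast<float>(q) + bias;
  }
  return out;
}

With bits = 4 and group_size = 64, for instance, each uint32 word would hold eight quantized values and every 64 elements of a row would share one scale and one bias.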
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }
 
+void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 6);
+
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
+  auto& lhs_indices = inputs[4];
+  auto& rhs_indices = inputs[5];
+
+  auto ensure_row_contiguous_last_dims = [](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
+
+  auto x = ensure_row_contiguous_last_dims(x_pre);
+  auto w = ensure_row_contiguous_last_dims(w_pre);
+  auto scales = ensure_row_contiguous_last_dims(scales_pre);
+  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  _bs_qmm_dispatch(
+      out,
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      group_size_,
+      bits_,
+      transpose_);
+}
+
 } // namespace mlx::core
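
A detail worth noting in BlockSparseQMM::eval: ensure_row_contiguous_last_dims only requires the trailing two (matrix) dimensions to be packed; leading batch dimensions may keep arbitrary strides, which is why the dispatch goes through elem_to_loc when turning a gathered index into a data offset. The same check restated over plain shape/stride vectors (a simplified sketch, not MLX's array API):

#include <cstddef>
#include <vector>

// Sketch: true when the last two dims are row-contiguous, i.e. the
// second-to-last stride equals the last dim's extent and the last stride is 1.
bool last_two_dims_row_contiguous(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t ndim = shape.size();
  return strides[ndim - 2] == static_cast<size_t>(shape[ndim - 1]) &&
      strides[ndim - 1] == 1;
}

// Example: a row-major (4, 8, 16) array has strides (128, 16, 1) and passes;
// a view transposed in its last two dims, shape (4, 16, 8) with strides
// (128, 1, 16), fails and triggers the general copy in eval above.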
