Commit f390957

Block sparse mm (#1058)
1 parent 17f57df commit f390957

File tree

15 files changed (+1323, -75 lines)


docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Operations
    bitwise_or
    bitwise_xor
    block_masked_mm
+   block_sparse_mm
    broadcast_to
    ceil
    clip
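
For orientation, the entry above documents the new gathered matrix-multiply op. A minimal usage sketch, assuming the Python binding is mx.block_sparse_mm(a, b, lhs_indices, rhs_indices) with uint32 index arrays (the signature is inferred from the CPU eval later in this diff, not quoted from this commit):

import mlx.core as mx

# Two stacks of candidate matrices; the index arrays choose which pair
# multiplies for each output batch element.
a = mx.random.normal((4, 8, 16))   # 4 candidate (M=8, K=16) matrices
b = mx.random.normal((4, 16, 32))  # 4 candidate (K=16, N=32) matrices

# uint32 indices, matching the data<uint32_t> reads in the CPU eval.
lhs_indices = mx.array([0, 2], dtype=mx.uint32)
rhs_indices = mx.array([1, 3], dtype=mx.uint32)

out = mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
print(out.shape)  # (2, 8, 32): one GEMM per index pair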

mlx/backend/accelerate/primitives.cpp

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ DEFAULT(ArgReduce)
 DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
+DEFAULT(BlockSparseMM)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)

mlx/backend/common/default_primitives.cpp

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ DEFAULT(AsType)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
 DEFAULT(BlockMaskedMM)
+DEFAULT(BlockSparseMM)
 DEFAULT_MULTI(DivMod)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)

mlx/backend/common/masked_mm.cpp

Lines changed: 87 additions & 0 deletions
@@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[BlockSparseMM::eval] Currently only supports float32.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+
+  auto check_transpose = [](const array& arr) {
+    auto stx = arr.strides()[arr.ndim() - 2];
+    auto sty = arr.strides()[arr.ndim() - 1];
+    if (stx == arr.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, arr);
+    } else if (stx == 1 && sty == arr.shape(-2)) {
+      return std::make_tuple(true, sty, arr);
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      size_t stx = arr.shape(-1);
+      return std::make_tuple(false, stx, arr_copy);
+    }
+  };
+
+  auto [a_transposed, lda, a] = check_transpose(a_pre);
+  auto [b_transposed, ldb, b] = check_transpose(b_pre);
+
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+
+  if (M == 0 || N == 0) {
+    return;
+  }
+
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
+  // Get batch dims
+  auto batch_size_out = out.size() / (M * N);
+  size_t matrix_stride_out = M * N;
+
+  auto get_batch_dims = [](const auto& v) {
+    return decltype(v){v.begin(), v.end() - 2};
+  };
+
+  auto& lhs_indices = inputs[2];
+  auto& rhs_indices = inputs[3];
+
+  std::vector<int> batch_shape = get_batch_dims(out.shape());
+  int batch_ndim = batch_shape.size();
+
+  std::vector<int> batch_shape_A = get_batch_dims(a.shape());
+  std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
+  std::vector<int> batch_shape_B = get_batch_dims(b.shape());
+  std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
+
+  const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
+  const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
+
+  for (int i = 0; i < batch_size_out; i++) {
+    // Get index
+    uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
+    uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
+
+    cblas_sgemm(
+        CblasRowMajor,
+        a_transposed ? CblasTrans : CblasNoTrans, // transA
+        b_transposed ? CblasTrans : CblasNoTrans, // transB
+        M,
+        N,
+        K,
+        1.0f, // alpha
+        a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
+        lda,
+        b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
+        ldb,
+        0.0f, // beta
+        out.data<float>() + matrix_stride_out * i,
+        out.shape(-1) // ldc
+    );
+  }
+}
+
 } // namespace mlx::core
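
To make the gather semantics of the loop above concrete, here is a minimal NumPy sketch of what BlockSparseMM::eval computes (an illustration, not part of the commit). It assumes contiguous row-major inputs; the C++ version also handles transposed and strided layouts via check_transpose and elem_to_loc:

import numpy as np

def block_sparse_mm_ref(a, b, lhs_indices, rhs_indices):
    # Broadcast the index arrays against each other to form the output
    # batch shape, as the primitive's shape machinery does before eval.
    lhs_idx, rhs_idx = np.broadcast_arrays(lhs_indices, rhs_indices)
    a_mats = a.reshape(-1, a.shape[-2], a.shape[-1])  # flatten A batch dims
    b_mats = b.reshape(-1, b.shape[-2], b.shape[-1])  # flatten B batch dims
    M, N = a.shape[-2], b.shape[-1]
    out = np.empty(lhs_idx.shape + (M, N), dtype=np.float32)
    out_flat = out.reshape(-1, M, N)
    # One dense GEMM per output batch element, mirroring the cblas_sgemm loop.
    for i, (ia, ib) in enumerate(zip(lhs_idx.ravel(), rhs_idx.ravel())):
        out_flat[i] = a_mats[ia] @ b_mats[ib]
    return out

Repeated indices simply reuse the same operand matrix, so many output batch elements can share a small set of matrices without materializing gathered copies.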
