@@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[BlockSparseMM::eval] Currently only supports float32.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
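+  // Operands: inputs[0] and inputs[1] are the matrices; inputs[2] and
+  // inputs[3] are uint32 index arrays that pick one matrix from each operand
+  // for every output batch element.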
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+
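+  // Decide how to feed each operand to BLAS: use it directly if the last two
+  // axes are row-contiguous, flag it as transposed if they are
+  // column-contiguous, and otherwise copy into a packed row-major array.
+  // Returns (is_transposed, leading dimension, array to use).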
+  auto check_transpose = [](const array& arr) {
+    auto stx = arr.strides()[arr.ndim() - 2];
+    auto sty = arr.strides()[arr.ndim() - 1];
+    if (stx == arr.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, arr);
+    } else if (stx == 1 && sty == arr.shape(-2)) {
+      return std::make_tuple(true, sty, arr);
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      size_t stx = arr.shape(-1);
+      return std::make_tuple(false, stx, arr_copy);
+    }
+  };
+
+  auto [a_transposed, lda, a] = check_transpose(a_pre);
+  auto [b_transposed, ldb, b] = check_transpose(b_pre);
+
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+
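+  // Handle degenerate shapes: an empty output needs no work, and an empty
+  // contraction axis (K == 0) just zero-fills the output.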
+  if (M == 0 || N == 0) {
+    return;
+  }
+
+  if (K == 0) {
+    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
+    return;
+  }
+
+  // Get batch dims
+  auto batch_size_out = out.size() / (M * N);
+  size_t matrix_stride_out = M * N;
+
+  auto get_batch_dims = [](const auto& v) {
+    return decltype(v){v.begin(), v.end() - 2};
+  };
+
+  auto& lhs_indices = inputs[2];
+  auto& rhs_indices = inputs[3];
+
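+  // Batch bookkeeping: the per-operand batch shapes and strides are used
+  // below to turn a gathered index into an element offset within a or b.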
+  std::vector<int> batch_shape = get_batch_dims(out.shape());
+  int batch_ndim = batch_shape.size();
+
+  std::vector<int> batch_shape_A = get_batch_dims(a.shape());
+  std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
+  std::vector<int> batch_shape_B = get_batch_dims(b.shape());
+  std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
+
+  const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
+  const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
+
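+  // One sgemm per output matrix: look up which matrix of a and b this batch
+  // element uses, locate its data, and write the product into the output.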
+  for (int i = 0; i < batch_size_out; i++) {
+    // Get index
+    uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
+    uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
+
+    cblas_sgemm(
+        CblasRowMajor,
+        a_transposed ? CblasTrans : CblasNoTrans, // transA
+        b_transposed ? CblasTrans : CblasNoTrans, // transB
+        M,
+        N,
+        K,
+        1.0f, // alpha
+        a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
+        lda,
+        b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
+        ldb,
+        0.0f, // beta
+        out.data<float>() + matrix_stride_out * i,
+        out.shape(-1) // ldc
+    );
+  }
+}
+
 } // namespace mlx::core
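
For reference, the kernel above computes out[i] = A[lhs_indices[i]] x B[rhs_indices[i]] for each output batch element i. Below is a minimal standalone sketch of that contraction using naive loops in place of cblas_sgemm; it is not part of the commit, the helper name is hypothetical, and it assumes flat index arrays and contiguous row-major operands.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of the gathered batched matmul:
// out[i] = A[lhs_indices[i]] * B[rhs_indices[i]], where A holds row-major
// M x K matrices and B holds row-major K x N matrices.
void gathered_matmul_reference(
    const std::vector<float>& A,
    const std::vector<float>& B,
    const std::vector<uint32_t>& lhs_indices,
    const std::vector<uint32_t>& rhs_indices,
    std::vector<float>& out, // batch x M x N, row-major
    size_t M,
    size_t N,
    size_t K) {
  for (size_t i = 0; i < lhs_indices.size(); ++i) {
    // Gather the operand matrices selected for this batch element.
    const float* a = A.data() + size_t(lhs_indices[i]) * M * K;
    const float* b = B.data() + size_t(rhs_indices[i]) * K * N;
    float* c = out.data() + i * M * N;
    for (size_t m = 0; m < M; ++m) {
      for (size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (size_t k = 0; k < K; ++k) {
          acc += a[m * K + k] * b[k * N + n];
        }
        c[m * N + n] = acc;
      }
    }
  }
}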