@@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
197197 std::vector<int > B_bshape{b.shape ().begin (), b.shape ().end () - 2 };
198198 if (A_bshape != B_bshape) {
199199 std::ostringstream msg;
200- msg << " [matmul] Got matrices with incorrectly broadcasted shapes: "
201- << " A " << a.shape () << " , B " << b.shape () << " ." ;
200+ msg << " [matmul] Got matrices with incorrectly broadcasted shapes: " << " A "
201+ << a.shape () << " , B " << b.shape () << " ." ;
202202 throw std::runtime_error (msg.str ());
203203 }
204204
@@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
227227 std::vector<int > C_bshape{c.shape ().begin (), c.shape ().end () - 2 };
228228 if (A_bshape != B_bshape || A_bshape != C_bshape) {
229229 std::ostringstream msg;
230- msg << " [addmm] Got matrices with incorrectly broadcasted shapes: "
231- << " A " << a.shape () << " , B " << b.shape () << " , B " << c.shape ()
232- << " ." ;
230+ msg << " [addmm] Got matrices with incorrectly broadcasted shapes: " << " A "
231+ << a.shape () << " , B " << b.shape () << " , B " << c.shape () << " ." ;
233232 throw std::runtime_error (msg.str ());
234233 }
235234
@@ -332,8 +331,8 @@ void steel_matmul(
332331 << (transpose_b ? ' t' : ' n' ) << " _" << type_to_name (a) << " _"
333332 << type_to_name (C_split) << " _bm" << bm << " _bn" << bn << " _bk" << bk
334333 << " _wm" << wm << " _wn" << wn << " _MN_"
335- << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned"
336- << " _K_ " << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" ;
334+ << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned" << " _K_ "
335+ << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" ;
337336
338337 // Encode and dispatch gemm kernel
339338 auto & compute_encoder = d.get_command_encoder (s.index );
@@ -422,8 +421,8 @@ void steel_matmul(
422421 << (transpose_b ? ' t' : ' n' ) << " _" << type_to_name (a) << " _"
423422 << type_to_name (out) << " _bm" << bm << " _bn" << bn << " _bk" << bk
424423 << " _wm" << wm << " _wn" << wn << " _MN_"
425- << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned"
426- << " _K_ " << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" ;
424+ << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned" << " _K_ "
425+ << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" ;
427426
428427 // Encode and dispatch kernel
429428 auto & compute_encoder = d.get_command_encoder (s.index );
@@ -903,8 +902,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
903902 << (transpose_b ? ' t' : ' n' ) << " _" << type_to_name (a) << " _"
904903 << type_to_name (C_split) << " _bm" << bm << " _bn" << bn << " _bk" << bk
905904 << " _wm" << wm << " _wn" << wn << " _MN_"
906- << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned"
907- << " _K_ " << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" ;
905+ << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned" << " _K_ "
906+ << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" ;
908907
909908 // Encode and dispatch gemm kernel
910909 auto & compute_encoder = d.get_command_encoder (s.index );
@@ -992,8 +991,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
992991 << (transpose_b ? ' t' : ' n' ) << " _" << type_to_name (a) << " _"
993992 << type_to_name (out) << " _bm" << bm << " _bn" << bn << " _bk" << bk
994993 << " _wm" << wm << " _wn" << wn << " _MN_"
995- << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned"
996- << " _K_ " << ((K % bk == 0 ) ? " t" : " n" ) << " aligned"
994+ << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned" << " _K_ "
995+ << ((K % bk == 0 ) ? " t" : " n" ) << " aligned"
997996 << ((alpha_ == 1 . && beta_ == 1 .) ? " _add" : " _axpby" );
998997
999998 // Encode and dispatch kernel
0 commit comments