@@ -1122,7 +1122,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
11221122 // ///////////////////////////////////////////////////////////////////////////
11231123 // Check and collapse batch dimensions
11241124
1125- auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches (a, b);
1125+ bool has_op_mask = inputs.size () > 3 ;
1126+ auto & out_mask = inputs[2 ];
1127+
1128+ std::vector<int > batch_shape{1 };
1129+ int A_batch_str = 0 ;
1130+ int B_batch_str = 0 ;
1131+
1132+ std::vector<size_t > batch_strides;
1133+
1134+ if (out.ndim () > 2 ) {
1135+ auto get_batch_dims = [](const auto & v) {
1136+ return decltype (v){v.begin (), v.end () - 2 };
1137+ };
1138+
1139+ std::vector<int > bshape{out.shape ().begin (), out.shape ().end () - 2 };
1140+ std::vector<std::vector<size_t >> bstrides;
1141+
1142+ for (auto & arr : inputs) {
1143+ bstrides.emplace_back (arr.strides ().begin (), arr.strides ().end () - 2 );
1144+ }
1145+
1146+ auto [bshape_c, bstrides_c] = collapse_contiguous_dims (bshape, bstrides);
1147+ batch_shape = bshape_c;
1148+ A_batch_str = int (bstrides_c[0 ].back ());
1149+ B_batch_str = int (bstrides_c[1 ].back ());
1150+
1151+ for (auto & bstr : bstrides_c) {
1152+ batch_strides.insert (batch_strides.end (), bstr.begin (), bstr.end ());
1153+ }
1154+ } else {
1155+ batch_strides = std::vector<size_t >(inputs.size (), 0 );
1156+ }
11261157
11271158 auto batch_size_out = out.size () / (M * N);
11281159 int matrix_stride_out = M * N;
@@ -1142,7 +1173,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
11421173 << " _wm" << wm << " _wn" << wn << " _MN_"
11431174 << ((M % bm == 0 && N % bn == 0 ) ? " t" : " n" ) << " aligned" << " _K_"
11441175 << ((K % bk == 0 ) ? " t" : " n" ) << " aligned" << " _op_mask_"
1145- << (inputs. size () > 3 ? " T" : " N" );
1176+ << (has_op_mask ? " T" : " N" );
11461177
11471178 // Encode and dispatch kernel
11481179 auto & compute_encoder = d.get_command_encoder (s.index );
@@ -1166,8 +1197,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
11661197 /* const int ldd = */ N,
11671198 /* const int tiles_n = */ tn,
11681199 /* const int tiles_m = */ tm,
1169- /* const int batch_stride_a = */ int (A_batch_stride. back ()) ,
1170- /* const int batch_stride_b = */ int (B_batch_stride. back ()) ,
1200+ /* const int batch_stride_a = */ A_batch_str ,
1201+ /* const int batch_stride_b = */ B_batch_str ,
11711202 /* const int batch_stride_d = */ matrix_stride_out,
11721203 /* const int swizzle_log = */ swizzle_log,
11731204 /* const int gemm_k_iterations_aligned = */ (K / bk),
@@ -1181,42 +1212,21 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
11811212 MTL::Size group_dims = MTL::Size (32 , wn, wm);
11821213 MTL::Size grid_dims = MTL::Size (tn, tm, batch_size_out);
11831214
1184- std::vector<size_t > batch_strides = A_batch_stride;
1185- batch_strides.insert (
1186- batch_strides.end (), B_batch_stride.begin (), B_batch_stride.end ());
1187-
11881215 std::vector<int > mask_strides;
1189-
1190- auto & out_mask = inputs[2 ];
11911216 mask_strides.push_back (*(out_mask.strides ().end () - 1 ));
11921217 mask_strides.push_back (*(out_mask.strides ().end () - 2 ));
11931218
1194- batch_strides.insert (
1195- batch_strides.end (),
1196- out_mask.strides ().begin (),
1197- out_mask.strides ().end () - 2 );
1198-
1199- if (inputs.size () > 3 ) {
1219+ if (has_op_mask) {
12001220 auto & lhs_mask = inputs[3 ];
12011221 mask_strides.push_back (*(lhs_mask.strides ().end () - 1 ));
12021222 mask_strides.push_back (*(lhs_mask.strides ().end () - 2 ));
12031223
1204- batch_strides.insert (
1205- batch_strides.end (),
1206- lhs_mask.strides ().begin (),
1207- lhs_mask.strides ().end () - 2 );
1208-
12091224 compute_encoder.set_input_array (lhs_mask, 11 );
12101225
12111226 auto & rhs_mask = inputs[4 ];
12121227 mask_strides.push_back (*(rhs_mask.strides ().end () - 1 ));
12131228 mask_strides.push_back (*(rhs_mask.strides ().end () - 2 ));
12141229
1215- batch_strides.insert (
1216- batch_strides.end (),
1217- rhs_mask.strides ().begin (),
1218- rhs_mask.strides ().end () - 2 );
1219-
12201230 compute_encoder.set_input_array (rhs_mask, 12 );
12211231 }
12221232
0 commit comments