
Commit 85c8a91

Fix mask broadcasting bug and add relevant test (#1003)
1 parent 581b699 commit 85c8a91

File tree

3 files changed: +67 -35 lines changed

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal

Lines changed: 6 additions & 0 deletions
@@ -65,6 +65,12 @@ template <typename T,
       lhs_mask += batch_offsets.x;
       rhs_mask += batch_offsets.y;
     }
+  } else {
+    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
+    if(has_operand_mask) {
+      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
+      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
+    }
   }
 
   // Adjust for batch
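
The kernel addition handles the case where the batch dimensions have already been collapsed to a single dimension (the new else branch): each mask pointer is advanced by the batch index tid.z times that mask's own batch stride, read from a flat batch_strides buffer in which every input contributes params->batch_ndim entries, in input order (A, B, out_mask, lhs_mask, rhs_mask). A minimal Python sketch of that offset computation, with the function name invented here for illustration:

    # Hypothetical sketch of the new else branch: the packed buffer holds batch
    # strides per input in the order A, B, out_mask, lhs_mask, rhs_mask.
    def mask_batch_offsets(tid_z, batch_strides, batch_ndim, has_operand_mask):
        out_mask_off = tid_z * batch_strides[2 * batch_ndim]
        lhs_mask_off = rhs_mask_off = 0
        if has_operand_mask:
            lhs_mask_off = tid_z * batch_strides[3 * batch_ndim]
            rhs_mask_off = tid_z * batch_strides[4 * batch_ndim]
        return out_mask_off, lhs_mask_off, rhs_mask_off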

mlx/backend/metal/matmul.cpp

Lines changed: 36 additions & 26 deletions
@@ -1122,7 +1122,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
 
-  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
+  bool has_op_mask = inputs.size() > 3;
+  auto& out_mask = inputs[2];
+
+  std::vector<int> batch_shape{1};
+  int A_batch_str = 0;
+  int B_batch_str = 0;
+
+  std::vector<size_t> batch_strides;
+
+  if (out.ndim() > 2) {
+    auto get_batch_dims = [](const auto& v) {
+      return decltype(v){v.begin(), v.end() - 2};
+    };
+
+    std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
+    std::vector<std::vector<size_t>> bstrides;
+
+    for (auto& arr : inputs) {
+      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
+    }
+
+    auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
+    batch_shape = bshape_c;
+    A_batch_str = int(bstrides_c[0].back());
+    B_batch_str = int(bstrides_c[1].back());
+
+    for (auto& bstr : bstrides_c) {
+      batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
+    }
+  } else {
+    batch_strides = std::vector<size_t>(inputs.size(), 0);
+  }
 
   auto batch_size_out = out.size() / (M * N);
   int matrix_stride_out = M * N;
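
On the host side, the single collapse_batches(a, b) call is replaced by a collapse over the batch dimensions of every input, so each mask keeps its own batch strides rather than inheriting A's and B's. A rough sketch of how a flat batch index is later turned into a per-operand offset from the collapsed batch shape and that operand's strides; the helper name is assumed here for illustration:

    # Illustrative helper (not MLX code): map a flat batch index to an element
    # offset, given the collapsed batch shape and one operand's batch strides.
    def elem_to_loc(flat_index, batch_shape, strides):
        offset = 0
        for dim, stride in zip(reversed(batch_shape), reversed(strides)):
            offset += (flat_index % dim) * stride
            flat_index //= dim
        return offset

An operand that is broadcast along a batch dimension carries stride 0 there, so its offset stays fixed while the output's offset advances.
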
@@ -1142,7 +1173,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << "_wm" << wm << "_wn" << wn << "_MN_"
         << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
         << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
-        << (inputs.size() > 3 ? "T" : "N");
+        << (has_op_mask ? "T" : "N");
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1166,8 +1197,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const int batch_stride_a = */ int(A_batch_stride.back()),
-      /* const int batch_stride_b = */ int(B_batch_stride.back()),
+      /* const int batch_stride_a = */ A_batch_str,
+      /* const int batch_stride_b = */ B_batch_str,
       /* const int batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
@@ -1181,42 +1212,21 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
 
-  std::vector<size_t> batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-
   std::vector<int> mask_strides;
-
-  auto& out_mask = inputs[2];
   mask_strides.push_back(*(out_mask.strides().end() - 1));
   mask_strides.push_back(*(out_mask.strides().end() - 2));
 
-  batch_strides.insert(
-      batch_strides.end(),
-      out_mask.strides().begin(),
-      out_mask.strides().end() - 2);
-
-  if (inputs.size() > 3) {
+  if (has_op_mask) {
     auto& lhs_mask = inputs[3];
     mask_strides.push_back(*(lhs_mask.strides().end() - 1));
     mask_strides.push_back(*(lhs_mask.strides().end() - 2));
 
-    batch_strides.insert(
-        batch_strides.end(),
-        lhs_mask.strides().begin(),
-        lhs_mask.strides().end() - 2);
-
     compute_encoder.set_input_array(lhs_mask, 11);
 
     auto& rhs_mask = inputs[4];
     mask_strides.push_back(*(rhs_mask.strides().end() - 1));
     mask_strides.push_back(*(rhs_mask.strides().end() - 2));
 
-    batch_strides.insert(
-        batch_strides.end(),
-        rhs_mask.strides().begin(),
-        rhs_mask.strides().end() - 2);
-
     compute_encoder.set_input_array(rhs_mask, 12);
   }
 
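Taken together with the kernel change, the host code now binds one flat batch-stride buffer instead of splicing mask strides onto an A/B-only layout. A sketch of the assumed packing, with made-up values; a stride of 0 marks an operand broadcast over the batch:

    # Packed layout indexed by the kernel as 2 * batch_ndim (out_mask),
    # 3 * batch_ndim (lhs_mask), 4 * batch_ndim (rhs_mask). Values are made up.
    batch_ndim = 1
    batch_strides = [
        4096,  # A
        4096,  # B
        4,     # out_mask
        0,     # lhs_mask (broadcast over the batch)
        4,     # rhs_mask
    ]
    out_mask_stride = batch_strides[2 * batch_ndim]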

python/tests/test_blas.py

Lines changed: 25 additions & 9 deletions
@@ -717,36 +717,49 @@ def expand_mask(mask, block_size, Y, X):
             out = out * out_mask
             return out
 
-        def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
+        def test_shape(
+            M,
+            N,
+            K,
+            block_size,
+            transpose=False,
+            np_dtype=np.float32,
+            batch_A=(),
+            batch_B=(),
+        ):
             with self.subTest(
                 M=M,
                 N=N,
                 K=K,
                 block_size=block_size,
                 np_dtype=np_dtype,
                 transpose=transpose,
+                batch_A=batch_A,
+                batch_B=batch_B,
             ):
                 tm = (M + block_size - 1) // block_size
                 tn = (N + block_size - 1) // block_size
                 tk = (K + block_size - 1) // block_size
 
-                a_np = np.random.normal(size=(M, K)).astype(np_dtype)
-                b_np = np.random.normal(size=(K, N)).astype(np_dtype)
+                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
+                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
+
+                batch_out = np.broadcast_shapes(batch_A, batch_B)
 
-                a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
-                b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
-                out_np_mask = np.random.normal(size=(tm, tn)) < 0.0
+                a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
+                b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
+                out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0
 
                 a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
                     mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
                 )
 
                 if transpose:
-                    b_np = np.random.normal(size=(N, K)).astype(np_dtype)
+                    b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
                     b_mx = mx.array(b_np)
 
-                    b_np = b_np.T
-                    b_mx = b_mx.T
+                    b_np = np.swapaxes(b_np, -2, -1)
+                    b_mx = mx.swapaxes(b_mx, -2, -1)
 
                 out_np = np_block_masked_mm(
                     a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
@@ -779,6 +792,9 @@ def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
         test_shape(M, N, K, block_size, transpose=False)
         test_shape(M, N, K, block_size, transpose=True)
 
+        # Test broadcasting
+        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
+
         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)
         b_np = np.random.normal(size=(64,)).astype(np.float32)
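
The new test exercises broadcasting between mismatched operand and mask batch shapes, (1, 2) against (2, 2). For reference, a hedged end-user example of the same pattern; this assumes the op is exposed in Python as mx.block_masked_mm with mask_out / mask_lhs / mask_rhs keyword arguments, which may differ from the actual signature:

    import mlx.core as mx

    # Assumed API: mx.block_masked_mm(a, b, block_size, mask_out, mask_lhs, mask_rhs).
    block_size = 32                                  # 64x64 matrices -> 2x2 block masks
    a = mx.random.normal((1, 2, 64, 64))             # batch (1, 2)
    b = mx.random.normal((2, 2, 64, 64))             # batch (2, 2)
    mask_lhs = mx.random.normal((1, 2, 2, 2)) < 0    # block mask for a
    mask_rhs = mx.random.normal((2, 2, 2, 2)) < 0    # block mask for b
    mask_out = mx.random.normal((2, 2, 2, 2)) < 0    # broadcast batch (2, 2)
    out = mx.block_masked_mm(
        a, b, block_size=block_size,
        mask_out=mask_out, mask_lhs=mask_lhs, mask_rhs=mask_rhs,
    )
    print(out.shape)  # expected: (2, 2, 64, 64)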
