Skip to content

Commit eab2685

Browse files
authored
Float mask update (#1152)
* Float mask update * Update CPU impl
1 parent 50dfb66 commit eab2685

File tree

8 files changed

+709
-249
lines changed

8 files changed

+709
-249
lines changed

mlx/backend/common/masked_mm.cpp

Lines changed: 76 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,25 @@ namespace mlx::core {
1717

1818
namespace {
1919

20-
template <typename T>
20+
template <typename T, typename mask_t>
2121
inline void mask_matrix(
2222
T* data,
23-
const bool* mask,
23+
const mask_t* mask,
2424
int block_size,
2525
const int X,
2626
const int Y,
2727
const size_t X_data_str,
2828
const size_t Y_data_str,
2929
const size_t X_mask_str,
30-
const size_t Y_mask_str) {
30+
const size_t Y_mask_str,
31+
const size_t mask_offset) {
3132
int tX = (X + block_size - 1) / block_size;
3233
int tY = (Y + block_size - 1) / block_size;
3334

3435
for (int i = 0; i < tX; i++) {
3536
for (int j = 0; j < tY; j++) {
36-
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
37-
if (!do_mask) {
37+
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
38+
if (do_mask != 1) {
3839
int loc_x = i * block_size;
3940
int loc_y = j * block_size;
4041
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
4344
int size_y = std::min(block_size, Y - loc_y);
4445
for (int ii = 0; ii < size_x; ii++) {
4546
for (int jj = 0; jj < size_y; jj++) {
46-
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
47+
if constexpr (std::is_same_v<mask_t, bool>) {
48+
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
49+
} else {
50+
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
51+
}
4752
}
4853
}
4954
}
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
6267

6368
auto& a_pre = inputs[0];
6469
auto& b_pre = inputs[1];
65-
auto& out_mask = inputs[2];
6670

67-
auto check_transpose = [](const array& arr, bool do_copy) {
68-
auto stx = arr.strides()[arr.ndim() - 2];
69-
auto sty = arr.strides()[arr.ndim() - 1];
70-
if (stx == arr.shape(-1) && sty == 1) {
71-
if (do_copy) {
72-
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
73-
copy(arr, arr_copy, CopyType::Vector);
74-
return std::make_tuple(false, stx, arr_copy);
75-
}
76-
return std::make_tuple(false, stx, arr);
77-
} else if (stx == 1 && sty == arr.shape(-2)) {
78-
if (do_copy) {
79-
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
80-
copy(arr, arr_copy, CopyType::Vector);
81-
return std::make_tuple(true, sty, arr_copy);
82-
}
83-
return std::make_tuple(true, sty, arr);
84-
} else {
85-
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
86-
copy(arr, arr_copy, CopyType::General);
87-
size_t stx = arr.shape(-1);
88-
return std::make_tuple(false, stx, arr_copy);
89-
}
90-
};
71+
auto check_transpose =
72+
[](const array& arr, bool do_copy, bool expand_all = false) {
73+
auto stx = arr.strides()[arr.ndim() - 2];
74+
auto sty = arr.strides()[arr.ndim() - 1];
75+
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
76+
if (do_copy) {
77+
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
78+
copy(arr, arr_copy, CopyType::Vector);
79+
return std::make_tuple(false, stx, arr_copy);
80+
}
81+
return std::make_tuple(false, stx, arr);
82+
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
83+
if (do_copy) {
84+
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
85+
copy(arr, arr_copy, CopyType::Vector);
86+
return std::make_tuple(true, sty, arr_copy);
87+
}
88+
return std::make_tuple(true, sty, arr);
89+
} else {
90+
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
91+
copy(arr, arr_copy, CopyType::General);
92+
size_t stx = arr.shape(-1);
93+
return std::make_tuple(false, stx, arr_copy);
94+
}
95+
};
9196

9297
bool has_op_mask = inputs.size() > 3;
93-
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
94-
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
98+
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
99+
auto [a_transposed, lda, a] =
100+
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
101+
auto [b_transposed, ldb, b] =
102+
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
95103

96104
size_t M = a.shape(-2);
97105
size_t N = b.shape(-1);
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
114122
int Y,
115123
size_t X_data_str,
116124
size_t Y_data_str) {
117-
const bool* mask_ptr = mask.data<bool>() +
118-
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
119-
mask.shape(),
120-
mask.strides());
125+
size_t mask_offset = elem_to_loc(
126+
mask.shape(-1) * mask.shape(-2) * batch_idx,
127+
mask.shape(),
128+
mask.strides());
121129

122130
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
123131
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
124132

125-
return mask_matrix(
126-
data,
127-
mask_ptr,
128-
block_size,
129-
X,
130-
Y,
131-
X_data_str,
132-
Y_data_str,
133-
X_mask_str,
134-
Y_mask_str);
133+
if (mask.dtype() == bool_) {
134+
return mask_matrix(
135+
data,
136+
mask.data<bool>(),
137+
block_size,
138+
X,
139+
Y,
140+
X_data_str,
141+
Y_data_str,
142+
X_mask_str,
143+
Y_mask_str,
144+
mask_offset);
145+
} else {
146+
return mask_matrix(
147+
data,
148+
mask.data<float>(),
149+
block_size,
150+
X,
151+
Y,
152+
X_data_str,
153+
Y_data_str,
154+
X_mask_str,
155+
Y_mask_str,
156+
mask_offset);
157+
}
135158
};
136159

137-
for (int i = 0; i < (a.size() / (M * K)); ++i) {
160+
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
138161
// Adjust pointer
139162
float* ai =
140163
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
144167

145168
// Zero out blocks in a and b if needed
146169
if (has_op_mask) {
147-
auto& a_mask = inputs[3];
170+
auto& a_mask = inputs[inputs.size() - 2];
148171
mask_array(
149172
a_mask,
150173
ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
155178
a_transposed ? 1 : lda,
156179
a_transposed ? lda : 1);
157180

158-
auto& b_mask = inputs[4];
181+
auto& b_mask = inputs[inputs.size() - 1];
159182
mask_array(
160183
b_mask,
161184
bi,
@@ -186,7 +209,9 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
186209
);
187210

188211
// Zero out blocks in out
189-
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
212+
if (has_out_mask) {
213+
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
214+
}
190215
}
191216
}
192217

0 commit comments

Comments
 (0)