@@ -17,24 +17,25 @@ namespace mlx::core {
1717
1818namespace {
1919
20- template <typename T>
20+ template <typename T, typename mask_t >
2121inline void mask_matrix (
2222 T* data,
23- const bool * mask,
23+ const mask_t * mask,
2424 int block_size,
2525 const int X,
2626 const int Y,
2727 const size_t X_data_str,
2828 const size_t Y_data_str,
2929 const size_t X_mask_str,
30- const size_t Y_mask_str) {
30+ const size_t Y_mask_str,
31+ const size_t mask_offset) {
3132 int tX = (X + block_size - 1 ) / block_size;
3233 int tY = (Y + block_size - 1 ) / block_size;
3334
3435 for (int i = 0 ; i < tX; i++) {
3536 for (int j = 0 ; j < tY; j++) {
36- bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
37- if (! do_mask) {
37+ mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
38+ if (do_mask != 1 ) {
3839 int loc_x = i * block_size;
3940 int loc_y = j * block_size;
4041 T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
4344 int size_y = std::min (block_size, Y - loc_y);
4445 for (int ii = 0 ; ii < size_x; ii++) {
4546 for (int jj = 0 ; jj < size_y; jj++) {
46- data_block[ii * X_data_str + jj * Y_data_str] = T (0 .);
47+ if constexpr (std::is_same_v<mask_t , bool >) {
48+ data_block[ii * X_data_str + jj * Y_data_str] = T (0 .);
49+ } else {
50+ data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
51+ }
4752 }
4853 }
4954 }
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
6267
6368 auto & a_pre = inputs[0 ];
6469 auto & b_pre = inputs[1 ];
65- auto & out_mask = inputs[2 ];
6670
67- auto check_transpose = [](const array& arr, bool do_copy) {
68- auto stx = arr.strides ()[arr.ndim () - 2 ];
69- auto sty = arr.strides ()[arr.ndim () - 1 ];
70- if (stx == arr.shape (-1 ) && sty == 1 ) {
71- if (do_copy) {
72- array arr_copy (arr.shape (), arr.dtype (), nullptr , {});
73- copy (arr, arr_copy, CopyType::Vector);
74- return std::make_tuple (false , stx, arr_copy);
75- }
76- return std::make_tuple (false , stx, arr);
77- } else if (stx == 1 && sty == arr.shape (-2 )) {
78- if (do_copy) {
79- array arr_copy (arr.shape (), arr.dtype (), nullptr , {});
80- copy (arr, arr_copy, CopyType::Vector);
81- return std::make_tuple (true , sty, arr_copy);
82- }
83- return std::make_tuple (true , sty, arr);
84- } else {
85- array arr_copy (arr.shape (), arr.dtype (), nullptr , {});
86- copy (arr, arr_copy, CopyType::General);
87- size_t stx = arr.shape (-1 );
88- return std::make_tuple (false , stx, arr_copy);
89- }
90- };
71+ auto check_transpose =
72+ [](const array& arr, bool do_copy, bool expand_all = false ) {
73+ auto stx = arr.strides ()[arr.ndim () - 2 ];
74+ auto sty = arr.strides ()[arr.ndim () - 1 ];
75+ if (!expand_all && stx == arr.shape (-1 ) && sty == 1 ) {
76+ if (do_copy) {
77+ array arr_copy (arr.shape (), arr.dtype (), nullptr , {});
78+ copy (arr, arr_copy, CopyType::Vector);
79+ return std::make_tuple (false , stx, arr_copy);
80+ }
81+ return std::make_tuple (false , stx, arr);
82+ } else if (!expand_all && stx == 1 && sty == arr.shape (-2 )) {
83+ if (do_copy) {
84+ array arr_copy (arr.shape (), arr.dtype (), nullptr , {});
85+ copy (arr, arr_copy, CopyType::Vector);
86+ return std::make_tuple (true , sty, arr_copy);
87+ }
88+ return std::make_tuple (true , sty, arr);
89+ } else {
90+ array arr_copy (arr.shape (), arr.dtype (), nullptr , {});
91+ copy (arr, arr_copy, CopyType::General);
92+ size_t stx = arr.shape (-1 );
93+ return std::make_tuple (false , stx, arr_copy);
94+ }
95+ };
9196
9297 bool has_op_mask = inputs.size () > 3 ;
93- auto [a_transposed, lda, a] = check_transpose (a_pre, has_op_mask);
94- auto [b_transposed, ldb, b] = check_transpose (b_pre, has_op_mask);
98+ bool has_out_mask = inputs.size () == 3 || inputs.size () == 5 ;
99+ auto [a_transposed, lda, a] =
100+ check_transpose (a_pre, has_op_mask, inputs.back ().dtype () != bool_);
101+ auto [b_transposed, ldb, b] =
102+ check_transpose (b_pre, has_op_mask, inputs.back ().dtype () != bool_);
95103
96104 size_t M = a.shape (-2 );
97105 size_t N = b.shape (-1 );
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
114122 int Y,
115123 size_t X_data_str,
116124 size_t Y_data_str) {
117- const bool * mask_ptr = mask. data < bool >() +
118- elem_to_loc ( mask.shape (-1 ) * mask.shape (-2 ) * batch_idx,
119- mask.shape (),
120- mask.strides ());
125+ size_t mask_offset = elem_to_loc (
126+ mask.shape (-1 ) * mask.shape (-2 ) * batch_idx,
127+ mask.shape (),
128+ mask.strides ());
121129
122130 size_t X_mask_str = mask.strides ()[mask.ndim () - 2 ];
123131 size_t Y_mask_str = mask.strides ()[mask.ndim () - 1 ];
124132
125- return mask_matrix (
126- data,
127- mask_ptr,
128- block_size,
129- X,
130- Y,
131- X_data_str,
132- Y_data_str,
133- X_mask_str,
134- Y_mask_str);
133+ if (mask.dtype () == bool_) {
134+ return mask_matrix (
135+ data,
136+ mask.data <bool >(),
137+ block_size,
138+ X,
139+ Y,
140+ X_data_str,
141+ Y_data_str,
142+ X_mask_str,
143+ Y_mask_str,
144+ mask_offset);
145+ } else {
146+ return mask_matrix (
147+ data,
148+ mask.data <float >(),
149+ block_size,
150+ X,
151+ Y,
152+ X_data_str,
153+ Y_data_str,
154+ X_mask_str,
155+ Y_mask_str,
156+ mask_offset);
157+ }
135158 };
136159
137- for (int i = 0 ; i < (a .size () / (M * K )); ++i) {
160+ for (int i = 0 ; i < (out .size () / (M * size_t (N) )); ++i) {
138161 // Adjust pointer
139162 float * ai =
140163 a.data <float >() + elem_to_loc (M * K * i, a.shape (), a.strides ());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
144167
145168 // Zero out blocks in a and b if needed
146169 if (has_op_mask) {
147- auto & a_mask = inputs[3 ];
170+ auto & a_mask = inputs[inputs. size () - 2 ];
148171 mask_array (
149172 a_mask,
150173 ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
155178 a_transposed ? 1 : lda,
156179 a_transposed ? lda : 1 );
157180
158- auto & b_mask = inputs[4 ];
181+ auto & b_mask = inputs[inputs. size () - 1 ];
159182 mask_array (
160183 b_mask,
161184 bi,
@@ -186,7 +209,9 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
186209 );
187210
188211 // Zero out blocks in out
189- mask_array (out_mask, ci, block_size_, i, M, N, N, 1 );
212+ if (has_out_mask) {
213+ mask_array (inputs[2 ], ci, block_size_, i, M, N, N, 1 );
214+ }
190215 }
191216}
192217
0 commit comments