Skip to content

Commit 3290bfa

Browse files
authored
Add new sdpa function overload (#2035)
* Add new sdpa function overload * Address comments * Remove std::varaint from cpp sdpa function
1 parent 8777fd1 commit 3290bfa

File tree

3 files changed

+71
-17
lines changed

3 files changed

+71
-17
lines changed

mlx/fast.cpp

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,9 @@ array scaled_dot_product_attention(
567567
const array& keys,
568568
const array& values,
569569
const float scale,
570-
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
571-
StreamOrDevice s) {
570+
const std::string& mask_mode /* = "" */,
571+
const std::vector<array>& mask_arrs /* = {} */,
572+
StreamOrDevice s /* = {}*/) {
572573
for (const auto& tensor : {queries, keys, values}) {
573574
if (tensor.ndim() != 4) {
574575
std::ostringstream msg;
@@ -577,29 +578,49 @@ array scaled_dot_product_attention(
577578
throw std::invalid_argument(msg.str());
578579
}
579580
}
581+
// Check valid mask
582+
if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
583+
std::ostringstream msg;
584+
msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
585+
<< ". mask_mode must be 'causal', 'array' or ''.";
586+
throw std::invalid_argument(msg.str());
587+
}
580588

581589
bool do_causal = false;
582-
bool has_mask = !std::holds_alternative<std::monostate>(mask);
583-
bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
584-
bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
590+
bool has_mask = false;
591+
bool has_arr_mask = false;
585592
bool has_bool_mask = false;
586593

587-
if (has_str_mask) {
588-
if (std::get<std::string>(mask) != "causal") {
594+
if (mask_mode == "causal") {
595+
has_mask = true;
596+
do_causal = true;
597+
598+
if (!mask_arrs.empty()) {
589599
std::ostringstream msg;
590-
msg << "[scaled_dot_product_attention] invalid mask option '"
591-
<< std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
600+
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
601+
<< "'casusal'. No array masks supported.";
592602
throw std::invalid_argument(msg.str());
593-
} else {
594-
do_causal = true;
595603
}
596604
}
597605

598-
if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
606+
if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
607+
if (mask_arrs.size() != 1) {
608+
std::ostringstream msg;
609+
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
610+
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
611+
<< mask_arrs.size() << "arrays.";
612+
throw std::invalid_argument(msg.str());
613+
}
614+
615+
has_mask = true;
616+
has_arr_mask = true;
617+
has_bool_mask = mask_arrs[0].dtype() == bool_;
618+
}
619+
620+
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
599621
std::ostringstream msg;
600622
msg << "[scaled_dot_product_attention] the mask with shape "
601-
<< (std::get<array>(mask)).shape()
602-
<< " expected to have at most rank 4";
623+
<< mask_arrs[0].shape() << " expected to have at most rank 4.";
603624
throw std::invalid_argument(msg.str());
604625
}
605626

@@ -736,7 +757,7 @@ array scaled_dot_product_attention(
736757
std::vector<array> inputs = {q, k, v};
737758
if (has_arr_mask) {
738759
// Check type
739-
auto mask_arr = std::get<array>(mask);
760+
auto mask_arr = mask_arrs[0];
740761
has_bool_mask = mask_arr.dtype() == bool_;
741762
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
742763
std::ostringstream msg;

mlx/fast.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ array scaled_dot_product_attention(
4848
const array& keys,
4949
const array& values,
5050
const float scale,
51-
const std::variant<std::monostate, std::string, array>& mask = {},
51+
const std::string& mask_mode = "",
52+
const std::vector<array>& mask_arrs = {},
5253
StreamOrDevice s = {});
5354

5455
std::tuple<array, array, array> affine_quantize(

python/src/fast.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,39 @@ void init_fast(nb::module_& parent_module) {
124124

125125
m.def(
126126
"scaled_dot_product_attention",
127-
&mx::fast::scaled_dot_product_attention,
127+
[](const mx::array& queries,
128+
const mx::array& keys,
129+
const mx::array& values,
130+
const float scale,
131+
const std::variant<std::monostate, std::string, mx::array>& mask,
132+
mx::StreamOrDevice s) {
133+
bool has_mask = !std::holds_alternative<std::monostate>(mask);
134+
bool has_str_mask =
135+
has_mask && std::holds_alternative<std::string>(mask);
136+
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
137+
138+
if (has_mask) {
139+
if (has_str_mask) {
140+
auto mask_str = std::get<std::string>(mask);
141+
if (mask_str != "causal") {
142+
std::ostringstream msg;
143+
msg << "[scaled_dot_product_attention] invalid mask option '"
144+
<< mask_str << "'. Must be 'causal', or an array.";
145+
throw std::invalid_argument(msg.str());
146+
}
147+
return mx::fast::scaled_dot_product_attention(
148+
queries, keys, values, scale, mask_str, {}, s);
149+
} else {
150+
auto mask_arr = std::get<mx::array>(mask);
151+
return mx::fast::scaled_dot_product_attention(
152+
queries, keys, values, scale, "", {mask_arr}, s);
153+
}
154+
155+
} else {
156+
return mx::fast::scaled_dot_product_attention(
157+
queries, keys, values, scale, "", {}, s);
158+
}
159+
},
128160
"q"_a,
129161
"k"_a,
130162
"v"_a,

0 commit comments

Comments
 (0)