@@ -567,8 +567,9 @@ array scaled_dot_product_attention(
567567 const array& keys,
568568 const array& values,
569569 const float scale,
570- const std::variant<std::monostate, std::string, array>& mask /* = {}*/ ,
571- StreamOrDevice s) {
570+ const std::string& mask_mode /* = "" */ ,
571+ const std::vector<array>& mask_arrs /* = {} */ ,
572+ StreamOrDevice s /* = {}*/ ) {
572573 for (const auto & tensor : {queries, keys, values}) {
573574 if (tensor.ndim () != 4 ) {
574575 std::ostringstream msg;
@@ -577,29 +578,49 @@ array scaled_dot_product_attention(
577578 throw std::invalid_argument (msg.str ());
578579 }
579580 }
581+ // Check valid mask
582+ if (mask_mode != " " && mask_mode != " causal" && mask_mode != " array" ) {
583+ std::ostringstream msg;
584+ msg << " [scaled_dot_product_attention] Invalid mask_mode " << mask_mode
585+ << " . mask_mode must be 'causal', 'array' or ''." ;
586+ throw std::invalid_argument (msg.str ());
587+ }
580588
581589 bool do_causal = false ;
582- bool has_mask = !std::holds_alternative<std::monostate>(mask);
583- bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
584- bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
590+ bool has_mask = false ;
591+ bool has_arr_mask = false ;
585592 bool has_bool_mask = false ;
586593
587- if (has_str_mask) {
588- if (std::get<std::string>(mask) != " causal" ) {
594+ if (mask_mode == " causal" ) {
595+ has_mask = true ;
596+ do_causal = true ;
597+
598+ if (!mask_arrs.empty ()) {
589599 std::ostringstream msg;
590- msg << " [scaled_dot_product_attention] invalid mask option ' "
591- << std::get<std::string>(mask) << " '. Must be 'causal', or an array ." ;
600+ msg << " [scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
601+ << " 'casusal'. No array masks supported ." ;
592602 throw std::invalid_argument (msg.str ());
593- } else {
594- do_causal = true ;
595603 }
596604 }
597605
598- if (has_arr_mask && (std::get<array>(mask)).ndim () > 4 ) {
606+ if (mask_mode == " array" || (mask_mode == " " && !mask_arrs.empty ())) {
607+ if (mask_arrs.size () != 1 ) {
608+ std::ostringstream msg;
609+ msg << " [scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
610+ << " '" << mask_mode << " '. Only 1 mask array is supported, got "
611+ << mask_arrs.size () << " arrays." ;
612+ throw std::invalid_argument (msg.str ());
613+ }
614+
615+ has_mask = true ;
616+ has_arr_mask = true ;
617+ has_bool_mask = mask_arrs[0 ].dtype () == bool_;
618+ }
619+
620+ if (has_arr_mask && (mask_arrs[0 ]).ndim () > 4 ) {
599621 std::ostringstream msg;
600622 msg << " [scaled_dot_product_attention] the mask with shape "
601- << (std::get<array>(mask)).shape ()
602- << " expected to have at most rank 4" ;
623+ << mask_arrs[0 ].shape () << " expected to have at most rank 4." ;
603624 throw std::invalid_argument (msg.str ());
604625 }
605626
@@ -736,7 +757,7 @@ array scaled_dot_product_attention(
736757 std::vector<array> inputs = {q, k, v};
737758 if (has_arr_mask) {
738759 // Check type
739- auto mask_arr = std::get<array>(mask) ;
760+ auto mask_arr = mask_arrs[ 0 ] ;
740761 has_bool_mask = mask_arr.dtype () == bool_;
741762 if (promote_types (mask_arr.dtype (), final_type) != final_type) {
742763 std::ostringstream msg;
0 commit comments