@@ -849,8 +849,18 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
     std::shared_ptr<Tensor_attributes> alibi_slopes;
     int64_t alibi_slopes_size = 0;
 
+    mutable bool has_workaround_padding_mask = false;            // Will be edited in pre_validate_node()
+    mutable int32_t s_q_for_workaround_padding_mask = 0;         // Will be edited in pre_validate_node()
+    mutable int32_t s_kv_for_workaround_padding_mask = 0;        // Will be edited in pre_validate_node()
+    mutable std::shared_ptr<Tensor_attributes>
+        workaround_padding_mask_seq_len_q;  // Will be edited in pre_validate_node()
+    mutable std::shared_ptr<Tensor_attributes>
+        workaround_padding_mask_seq_len_kv;  // Will be edited in pre_validate_node()
+    mutable int64_t batch_size_for_workaround_padding_mask = 0;  // Will be edited in pre_validate_node()
+
+
    public:
-    SDPA_backward_attributes attributes;
+    mutable SDPA_backward_attributes attributes;  // Will be edited in pre_validate_node() for workaround padding mask
 
     SDPABackwardNode(SDPA_backward_attributes&& attributes_, detail::Context const& context)
         : NodeCRTP(context), attributes(std::move(attributes_)) {}
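
Side note on the `mutable` qualifiers above (a sketch under an assumption, not taken from the diff): the "Will be edited in pre_validate_node()" comments suggest the validation hook is a const member function, so any workaround state it records on the node has to be declared mutable. A minimal stand-in illustrating only that interplay:

    // Hypothetical stand-in for the node class; names are illustrative, not library API.
    struct NodeSketch {
        mutable bool has_workaround_padding_mask = false;

        void pre_validate() const {
            // Writing through a const `this` compiles only because the member is mutable.
            has_workaround_padding_mask = true;
        }
    };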
@@ -977,6 +987,20 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
                                        error_code_t::GRAPH_NOT_SUPPORTED,
                                        "Bias mask data type cannot be boolean");
 
+        if (s_q % 128 != 0 && attributes.padding_mask == false && is_ragged == false) {
+            CUDNN_FE_LOG_LABEL_ENDL("INFO: Workaround padding mask is enabled for s_q % 128 != 0 and use_padding_mask == false and is_ragged == false");
+            has_workaround_padding_mask = true;
+            batch_size_for_workaround_padding_mask = attributes.inputs.at(input_names::Q)->get_dim()[0];
+            s_q_for_workaround_padding_mask = s_q;
+            s_kv_for_workaround_padding_mask = s_kv;
+            workaround_padding_mask_seq_len_q = std::make_shared<Tensor_attributes>();
+            workaround_padding_mask_seq_len_q->set_name("workaround_padding_mask_seq_len_q").set_dim({batch_size_for_workaround_padding_mask, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(DataType_t::INT32);
+            workaround_padding_mask_seq_len_kv = std::make_shared<Tensor_attributes>();
+            workaround_padding_mask_seq_len_kv->set_name("workaround_padding_mask_seq_len_kv").set_dim({batch_size_for_workaround_padding_mask, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(DataType_t::INT32);
+            attributes.set_padding_mask(true);
+            attributes.set_seq_len_q(workaround_padding_mask_seq_len_q).set_seq_len_kv(workaround_padding_mask_seq_len_kv);
+        }
+
         // validate options for padding mask
         auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q);
         bool const has_seq_len_q = (seq_len_q != attributes.inputs.end()) && (seq_len_q->second != nullptr);
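
A standalone sketch of the fallback this hunk introduces (helper names here are illustrative, not library API): the workaround fires only when s_q is not a multiple of 128 and the caller supplied neither a padding mask nor ragged offsets; the synthesized per-batch lengths are simply the full s_q and s_kv, so the mask filters nothing out numerically but routes the backward pass onto the padded-kernel path.

    #include <cstdint>
    #include <vector>

    // Illustrative trigger check, mirroring `s_q % 128 != 0 && !padding_mask && !is_ragged`.
    bool needs_workaround_padding_mask(int64_t s_q, bool has_user_padding_mask, bool is_ragged) {
        return (s_q % 128 != 0) && !has_user_padding_mask && !is_ragged;
    }

    // Illustrative length synthesis: every batch entry gets the full sequence length.
    std::vector<int32_t> full_length_seq_lens(int64_t batch, int32_t s) {
        return std::vector<int32_t>(batch, s);
    }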
@@ -1694,6 +1718,10 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
         size += dV_fullhead_size;
         size += softmax_sum_size;
 
+        if (has_workaround_padding_mask) {
+            size += batch_size_for_workaround_padding_mask * sizeof(int32_t) * 2;
+        }
+
         return size;
     }
 
@@ -1737,6 +1765,34 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
             offset = offset + softmax_sum_size;
         }
 
+        if (has_workaround_padding_mask) {
+            CUDNN_FE_LOG_LABEL_ENDL("INFO: Collecting workaround padding mask tensors with batch size "
+                                    << batch_size_for_workaround_padding_mask << " with UIDs "
+                                    << workaround_padding_mask_seq_len_q->get_uid() << " and "
+                                    << workaround_padding_mask_seq_len_kv->get_uid());
+            std::vector<int32_t> workaround_padding_mask_seq_len_q_vec(batch_size_for_workaround_padding_mask,
+                                                                       s_q_for_workaround_padding_mask);
+            std::vector<int32_t> workaround_padding_mask_seq_len_kv_vec(batch_size_for_workaround_padding_mask,
+                                                                        s_kv_for_workaround_padding_mask);
+
+            // reinterpret_cast the int32_t vector data to a float vector for workspace_modifications
+            std::vector<float> workaround_padding_mask_seq_len_q_vec_float(
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_q_vec.data()),
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_q_vec.data()) +
+                    batch_size_for_workaround_padding_mask);
+            std::vector<float> workaround_padding_mask_seq_len_kv_vec_float(
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_kv_vec.data()),
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_kv_vec.data()) +
+                    batch_size_for_workaround_padding_mask);
+
+            workspace_modifications.emplace(workaround_padding_mask_seq_len_q->get_uid(),
+                                            std::make_tuple(0, offset, workaround_padding_mask_seq_len_q_vec_float));
+            offset = offset + batch_size_for_workaround_padding_mask * sizeof(float);
+            workspace_modifications.emplace(workaround_padding_mask_seq_len_kv->get_uid(),
+                                            std::make_tuple(0, offset, workaround_padding_mask_seq_len_kv_vec_float));
+            offset = offset + batch_size_for_workaround_padding_mask * sizeof(float);
+        }
+
         return {error_code_t::OK, ""};
     }
 
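
The float vectors in the hunk above exist only because workspace_modifications carries its host-side payload as std::vector<float>; the int32_t sequence lengths are copied through float-typed storage byte-for-byte, with no numeric conversion. A standalone sketch (not part of the diff) of that round trip, using made-up values:

    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    int main() {
        const int64_t batch = 4;
        std::vector<int32_t> seq_len(batch, 117);  // e.g. s_q = 117, not a multiple of 128

        // Same copy the diff performs: view the int32 payload through float pointers.
        // sizeof(float) == sizeof(int32_t), so element and byte counts line up.
        std::vector<float> transport(reinterpret_cast<float*>(seq_len.data()),
                                     reinterpret_cast<float*>(seq_len.data()) + batch);

        // Copying the bytes back as int32 (as the device buffer effectively does once
        // the workspace is uploaded) recovers the original lengths exactly.
        std::vector<int32_t> recovered(batch);
        std::memcpy(recovered.data(), transport.data(), batch * sizeof(int32_t));
        for (int32_t v : recovered) assert(v == 117);
        return 0;
    }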