Commit 243c7ff

v1.12.1
Patch release - Added a dummy padding mask when the actual seq is not a multiple of tile size in the bprop
1 parent f937055 commit 243c7ff

File tree

4 files changed: +60 -4 lines changed

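For context before the per-file diffs: a minimal standalone sketch, not part of this commit, of the condition the patched pre_validate_node() checks and of the extra workspace the workaround reserves. The batch size and sequence length below are made-up values; the 128-wide tile check and the batch * sizeof(int32_t) * 2 sizing come from the hunks in scaled_dot_product_flash_attention.h further down.

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical shapes, for illustration only.
    int64_t batch = 4;
    int32_t s_q = 1000;             // not a multiple of the 128-wide tile
    bool use_padding_mask = false;  // user did not request a padding mask
    bool is_ragged = false;         // dense (non-ragged) layout

    // Same condition the patched pre_validate_node() checks before enabling
    // the dummy padding mask.
    bool workaround_needed = (s_q % 128 != 0) && !use_padding_mask && !is_ragged;

    // The patch reserves room for two per-batch INT32 seq-len tensors (Q and KV).
    int64_t extra_workspace_bytes =
        workaround_needed ? static_cast<int64_t>(batch * sizeof(int32_t) * 2) : 0;

    std::printf("workaround needed: %s, extra workspace: %lld bytes\n",
                workaround_needed ? "yes" : "no",
                static_cast<long long>(extra_workspace_bytes));
    return 0;
}

For these made-up shapes it reports that the workaround engages and asks for 4 * 4 * 2 = 32 extra bytes of workspace.
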
CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.23)
 
-project(cudnn_frontend VERSION 1.12.0)
+project(cudnn_frontend VERSION 1.12.1)
 
 option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
 option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)

include/cudnn_frontend/node/scaled_dot_product_flash_attention.h

Lines changed: 57 additions & 1 deletion
@@ -849,8 +849,18 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
    std::shared_ptr<Tensor_attributes> alibi_slopes;
    int64_t alibi_slopes_size = 0;
 
+    mutable bool has_workaround_padding_mask         = false;  // Will be edited in pre_validate_node()
+    mutable int32_t s_q_for_workaround_padding_mask  = 0;      // Will be edited in pre_validate_node()
+    mutable int32_t s_kv_for_workaround_padding_mask = 0;      // Will be edited in pre_validate_node()
+    mutable std::shared_ptr<Tensor_attributes>
+        workaround_padding_mask_seq_len_q;  // Will be edited in pre_validate_node()
+    mutable std::shared_ptr<Tensor_attributes>
+        workaround_padding_mask_seq_len_kv;  // Will be edited in pre_validate_node()
+    mutable int64_t batch_size_for_workaround_padding_mask = 0;  // Will be edited in pre_validate_node()
+
+
   public:
-    SDPA_backward_attributes attributes;
+    mutable SDPA_backward_attributes attributes;  // Will be edited in pre_validate_node() for workaround padding mask
 
    SDPABackwardNode(SDPA_backward_attributes&& attributes_, detail::Context const& context)
        : NodeCRTP(context), attributes(std::move(attributes_)) {}

@@ -977,6 +987,20 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
                                       error_code_t::GRAPH_NOT_SUPPORTED,
                                       "Bias mask data type cannot be boolean");
 
+        if (s_q % 128 != 0 && attributes.padding_mask == false && is_ragged == false) {
+            CUDNN_FE_LOG_LABEL_ENDL("INFO: Workaround padding mask is enabled for s_q % 128 != 0 and use_padding_mask == false and is_ragged == false");
+            has_workaround_padding_mask            = true;
+            batch_size_for_workaround_padding_mask = attributes.inputs.at(input_names::Q)->get_dim()[0];
+            s_q_for_workaround_padding_mask        = s_q;
+            s_kv_for_workaround_padding_mask       = s_kv;
+            workaround_padding_mask_seq_len_q      = std::make_shared<Tensor_attributes>();
+            workaround_padding_mask_seq_len_q->set_name("workaround_padding_mask_seq_len_q").set_dim({batch_size_for_workaround_padding_mask,1,1,1}).set_stride({1,1,1,1}).set_data_type(DataType_t::INT32);
+            workaround_padding_mask_seq_len_kv = std::make_shared<Tensor_attributes>();
+            workaround_padding_mask_seq_len_kv->set_name("workaround_padding_mask_seq_len_kv").set_dim({batch_size_for_workaround_padding_mask,1,1,1}).set_stride({1,1,1,1}).set_data_type(DataType_t::INT32);
+            attributes.set_padding_mask(true);
+            attributes.set_seq_len_q(workaround_padding_mask_seq_len_q).set_seq_len_kv(workaround_padding_mask_seq_len_kv);
+        }
+
        // validate options for padding mask
        auto const& seq_len_q    = attributes.inputs.find(input_names::SEQ_LEN_Q);
        bool const has_seq_len_q = (seq_len_q != attributes.inputs.end()) && (seq_len_q->second != nullptr);

@@ -1694,6 +1718,10 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
        size += dV_fullhead_size;
        size += softmax_sum_size;
 
+        if (has_workaround_padding_mask) {
+            size += batch_size_for_workaround_padding_mask * sizeof(int32_t) * 2;
+        }
+
        return size;
    }
 
@@ -1737,6 +1765,34 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
            offset = offset + softmax_sum_size;
        }
 
+        if (has_workaround_padding_mask) {
+            CUDNN_FE_LOG_LABEL_ENDL("INFO: Collecting workaround padding mask tensors with batch size "
+                                    << batch_size_for_workaround_padding_mask << " with UIDs "
+                                    << workaround_padding_mask_seq_len_q->get_uid() << " and "
+                                    << workaround_padding_mask_seq_len_kv->get_uid());
+            std::vector<int32_t> workaround_padding_mask_seq_len_q_vec(batch_size_for_workaround_padding_mask,
+                                                                       s_q_for_workaround_padding_mask);
+            std::vector<int32_t> workaround_padding_mask_seq_len_kv_vec(batch_size_for_workaround_padding_mask,
+                                                                        s_kv_for_workaround_padding_mask);
+
+            // reinterpret_cast the int32_t vector data to float vector for workspace_modifications
+            std::vector<float> workaround_padding_mask_seq_len_q_vec_float(
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_q_vec.data()),
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_q_vec.data()) +
+                    batch_size_for_workaround_padding_mask);
+            std::vector<float> workaround_padding_mask_seq_len_kv_vec_float(
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_kv_vec.data()),
+                reinterpret_cast<float*>(workaround_padding_mask_seq_len_kv_vec.data()) +
+                    batch_size_for_workaround_padding_mask);
+
+            workspace_modifications.emplace(workaround_padding_mask_seq_len_q->get_uid(),
+                                            std::make_tuple(0, offset, workaround_padding_mask_seq_len_q_vec_float));
+            offset = offset + batch_size_for_workaround_padding_mask * sizeof(float);
+            workspace_modifications.emplace(workaround_padding_mask_seq_len_kv->get_uid(),
+                                            std::make_tuple(0, offset, workaround_padding_mask_seq_len_kv_vec_float));
+            offset = offset + batch_size_for_workaround_padding_mask * sizeof(float);
+        }
+
        return {error_code_t::OK, ""};
    }

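The last hunk above passes the dummy INT32 sequence lengths through workspace_modifications, whose payload is a std::vector<float>. Below is a standalone sketch of that bit-reinterpretation, for illustration only; the batch size and the value 1000 are made up, not from the diff.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    // Made-up batch size and sequence length, mirroring the node's per-batch
    // dummy seq-len fill (every batch entry gets the full sequence length).
    int64_t batch = 4;
    int32_t s_q = 1000;
    std::vector<int32_t> seq_len_q_vec(static_cast<size_t>(batch), s_q);

    // Same trick as the diff: view the int32 storage as float so the values can
    // travel inside the float-typed workspace_modifications vector.
    std::vector<float> seq_len_q_vec_float(
        reinterpret_cast<float*>(seq_len_q_vec.data()),
        reinterpret_cast<float*>(seq_len_q_vec.data()) + batch);

    // Copying those floats into workspace memory keeps the int32 bit pattern,
    // so reading the buffer back as INT32 recovers the original value.
    int32_t round_trip = 0;
    std::memcpy(&round_trip, seq_len_q_vec_float.data(), sizeof(round_trip));
    std::printf("original: %d, after float round trip: %d\n", s_q, round_trip);
    return 0;
}

Because the float elements are only ever copied byte-for-byte into the workspace buffer, the engine can read the same bytes back as INT32 sequence lengths.
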
include/cudnn_frontend_version.h

Lines changed: 1 addition & 1 deletion
@@ -24,6 +24,6 @@
 
 #define CUDNN_FRONTEND_MAJOR_VERSION 1
 #define CUDNN_FRONTEND_MINOR_VERSION 12
-#define CUDNN_FRONTEND_PATCH_VERSION 0
+#define CUDNN_FRONTEND_PATCH_VERSION 1
 #define CUDNN_FRONTEND_VERSION \
     ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION)

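With this bump the packed version macro works out to (1 * 10000) + (12 * 100) + 1 = 11201, up from 11200 for 1.12.0.
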
python/cudnn/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def is_windows():
 
 from .datatypes import _library_type, _is_torch_tensor
 
-__version__ = "1.12.0"
+__version__ = "1.12.1"
 
 
 def _tensor(
