Skip to content

Commit 576e0ed

Browse files
authored
Allow d > 128 for cudnn version >= 9.11.0 for B100 (#145)
1 parent 724f0ec commit 576e0ed

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

include/cudnn_frontend/node/scaled_dot_product_flash_attention.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#pragma once
1+
#pragma once
22

33
#include <cstdlib>
44

@@ -940,6 +940,11 @@ class SDPABackwardNode : public NodeCRTP<SDPABackwardNode> {
940940
RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0),
941941
error_code_t::GRAPH_NOT_SUPPORTED,
942942
"Num hidden_dim shoud be less than or equal to 256 and hidden_dim should be multiple of 8");
943+
} else if (prop.major == 10 && detail::get_backend_version() >= 91100) {
944+
// validate basic dimension requirements
945+
RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk % 8 != 0) || (d_v % 8 != 0),
946+
error_code_t::GRAPH_NOT_SUPPORTED,
947+
"Num hidden_dim shoud be should be multiple of 8");
943948
} else {
944949
// validate basic dimension requirements
945950
RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0),

0 commit comments

Comments
 (0)