
Commit ee971b1

# cudnn frontend v1.9 release notes (#123)
## New API

### cudnn Flex Attention

`SDPA_attributes` and `SDPA_bprop_attributes` now accept a score_mod function through the `set_score_mod` and `set_score_mod_bprop` APIs. The function accepts a custom chain of pointwise operations that operate on the attention score matrix. Some common functors, such as the causal mask, sliding window mask, and soft capping, have been added to the headers as a reference. More usage examples have been added to the samples for [fprop](fp16_fwd_with_flexible_graphs.cpp) and [bprop](fp16_bwd_with_flexible_graphs.cpp).

### Improvements

- Added support for THD format and sliding window mask.
- Added support for THD format and bottom-right causal mask.
- Added new parameters `set_max_total_seq_len_q`/`set_max_total_seq_len_kv` on the sdpa bprop node. These help reduce the workspace size required when running with THD format.
- Allowed creation of serialized json for dgrad, wgrad, and resample operations.
- Added a more detailed diagnostic message when the compiled version of cudnn does not match the runtime version of cudnn.

### Bug fixes

- Fixed an issue where log messages contained unparseable data at the end.
- Fixed an issue where building the python pip wheel would hang.
- Fixed natively creating cuda graphs for SDPA with alibi masks.

### New samples

- Added a new sample for Layernorm with dynamic shapes and a kernel cache, showcasing the reduced plan build time when using the kernel cache.
1 parent 936021b commit ee971b1


43 files changed (+1831 −585 lines)

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.17)
 
-project(cudnn_frontend VERSION 1.8.0)
+project(cudnn_frontend VERSION 1.9.0)
 
 option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
 option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
```

docs/operations/Attention.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -175,6 +175,8 @@ set_paged_attention_v_table(std::shared_ptr<Tensor_attributes> value);
 SDPA_attributes&
 set_paged_attention_max_seq_len_kv(int const value);
 
+SDPA_attributes&
+set_score_mod(std::function<Tensor_t(Graph_t, Tensor_t)>);
 ```
 
 #### Python API:
@@ -307,6 +309,9 @@ set_deterministic_algorithm(bool const value);
 
 SDPA_backward_attributes&
 set_compute_data_type(DataType_t const value);
+
+SDPA_backward_attributes&
+set_score_mod(std::function<Tensor_t(Graph_t, Tensor_t)>);
 ```
 
 #### Python API:
@@ -720,3 +725,8 @@ cuDNN layout support for variable sequence length includes (but is not limited t
 - Valid tokens are not packed together\
   `Q = a0abbb00bb000000`\
   Ragged offset is insufficient to represent this. This case is NOT supported.
+
+### cudnn Flex Attention API
+
+SDPA and SDPA_backward ops now accept the functors `set_score_mod` and `set_score_mod_bprop`, which allow modification of the attention score matrix. These functions can be used to program a sub-graph of pointwise operations that is then applied as the score modifier. Note that using these functions is mutually exclusive with the ready-made options.
````

include/cudnn_backend_base.h

Lines changed: 3 additions & 2 deletions
```diff
@@ -30,7 +30,8 @@ namespace cudnn_frontend {
 /// OpaqueBackendPointer class
 /// Holds the raws pointer to backend_descriptor
 /// Usage is to wrap this into a smart pointer as
-/// it helps to create and destroy the backencpointer
+/// it helps to create and destroy the backendpointer
+
 class OpaqueBackendPointer {
     cudnnBackendDescriptor_t m_desc = nullptr;  //!< Raw void pointer
     cudnnStatus_t status = CUDNN_STATUS_SUCCESS;  //!< status of creation of the Descriptor
@@ -153,7 +154,7 @@ class BackendDescriptor {
         : pointer(pointer_), status(status_), err_msg(err_msg_) {}
     BackendDescriptor() = default;
 
-    virtual ~BackendDescriptor(){};
+    virtual ~BackendDescriptor() {};
 
     ManagedOpaqueDescriptor pointer;  //! Shared pointer of the OpaqueBackendPointer
```
include/cudnn_frontend.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -123,6 +123,7 @@
 #include "cudnn_frontend/graph_interface.h"
 #include "cudnn_frontend/utils/serialize.h"
 #include "cudnn_frontend/backend/kernel_cache.h"
+#include "cudnn_frontend/utils/attn_score_modifiers.h"
 
 #include "cudnn_frontend_version.h"
```

include/cudnn_frontend/graph_helpers.h

Lines changed: 24 additions & 2 deletions
```diff
@@ -1,3 +1,25 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
 #pragma once
 
 #include <unordered_map>
@@ -31,8 +53,8 @@ enum class [[nodiscard]] error_code_t {
 typedef struct [[nodiscard]] error_object {
     error_code_t code;
     std::string err_msg;
-    error_object() : code(error_code_t::OK), err_msg(""){};
-    error_object(error_code_t err, std::string msg) : code(err), err_msg(msg){};
+    error_object() : code(error_code_t::OK), err_msg("") {};
+    error_object(error_code_t err, std::string msg) : code(err), err_msg(msg) {};
 
     error_code_t
     get_code() {
```

include/cudnn_frontend/graph_interface.h

Lines changed: 28 additions & 0 deletions
```diff
@@ -78,6 +78,11 @@ class Graph : public ICudnn, public INode {
         RETURN_CUDNN_FRONTEND_ERROR_IF(((is_dynamic_shape_enabled == false) && (kernel_cache != nullptr)),
                                        error_code_t::GRAPH_NOT_SUPPORTED,
                                        "Kernel caching enabled but dynamic shapes is disabled");
+        if (detail::get_backend_version() != detail::get_compiled_version()) {
+            CUDNN_FE_LOG_LABEL_ENDL("INFO: The cuDNN version used at compilation ("
+                                    << detail::get_compiled_version() << ") and the one used at runtime ("
+                                    << detail::get_backend_version() << ") differ.");
+        }
         return {error_code_t::OK, ""};
     }
 
@@ -311,6 +316,7 @@ class Graph : public ICudnn, public INode {
                                           vec_data.data(),
                                           vec_data.size() * sizeof(float),
                                           cudaMemcpyHostToDevice));
+                uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
             }
             // 1 means memset
             else if (operation_type == 1) {
@@ -436,6 +442,7 @@ class Graph : public ICudnn, public INode {
                                           vec_data.data(),
                                           vec_data.size() * sizeof(float),
                                           cudaMemcpyHostToDevice));
+                uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
             }
             // 1 means memset
             else if (operation_type == 1) {
@@ -1320,6 +1327,21 @@ class Graph : public ICudnn, public INode {
                 CHECK_TENSORS(sdpa_fp8_attributes);
                 FILL_GLOBAL_IO_TENSOR_MAP(sdpa_fp8_attributes);
                 sub_nodes.emplace_back(std::make_unique<SDPAFP8Node>(std::move(sdpa_fp8_attributes), context));
+            } else if (tag == "RESAMPLE") {
+                auto resample_attributes = j_sub_node.get<Resample_attributes>();
+                CHECK_TENSORS(resample_attributes);
+                FILL_GLOBAL_IO_TENSOR_MAP(resample_attributes);
+                sub_nodes.emplace_back(std::make_unique<ResampleNode>(std::move(resample_attributes), context));
+            } else if (tag == "CONV_DGRAD") {
+                auto dgrad_attributes = j_sub_node.get<Conv_dgrad_attributes>();
+                CHECK_TENSORS(dgrad_attributes);
+                FILL_GLOBAL_IO_TENSOR_MAP(dgrad_attributes);
+                sub_nodes.emplace_back(std::make_unique<DgradNode>(std::move(dgrad_attributes), context));
+            } else if (tag == "CONV_WGRAD") {
+                auto wgrad_attributes = j_sub_node.get<Conv_wgrad_attributes>();
+                CHECK_TENSORS(wgrad_attributes);
+                FILL_GLOBAL_IO_TENSOR_MAP(wgrad_attributes);
+                sub_nodes.emplace_back(std::make_unique<WgradNode>(std::move(wgrad_attributes), context));
             }
         }
 #undef CHECK_TENSORS
@@ -1699,6 +1721,9 @@ Graph::conv_fprop(std::shared_ptr<Tensor_attributes> x,
                   std::shared_ptr<Tensor_attributes> w,
                   Conv_fprop_attributes attributes) {
     // Make required output tensors
+    if (attributes.name.empty()) {
+        attributes.name += std::to_string(sub_nodes.size());
+    }
     auto Y = output_tensor(attributes.name + "::Y");
     attributes.outputs[Conv_fprop_attributes::output_names::Y] = Y;
 
@@ -1718,6 +1743,9 @@ Graph::dbn_weight(std::shared_ptr<Tensor_attributes> dy,
                   std::shared_ptr<Tensor_attributes> inv_variance,
                   std::shared_ptr<Tensor_attributes> scale,
                   DBN_weight_attributes attributes) {
+    if (attributes.name.empty()) {
+        attributes.name += std::to_string(sub_nodes.size());
+    }
     // Make required output tensors
     auto DBIAS = attributes.outputs[DBN_weight_attributes::output_names::DBIAS] =
         output_tensor(attributes.name + "::DBIAS");
```

include/cudnn_frontend/graph_properties.h

Lines changed: 32 additions & 1 deletion
```diff
@@ -1103,6 +1103,7 @@ class Resample_attributes : public Attributes<Resample_attributes> {
                               name,
                               inputs,
                               outputs,
+                              is_inference,
                               resample_mode,
                               padding_mode,
                               pre_padding,
@@ -1407,6 +1408,11 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
     friend class SDPANode;
     friend class Graph;
 
+    using Tensor_t = std::shared_ptr<Tensor_attributes>;
+    using Graph_t  = std::shared_ptr<Graph>;
+
+    using AttentionScoreModifier_t = std::function<Tensor_t(Graph_t, Tensor_t)>;
+
     std::optional<bool> is_inference;
     bool alibi_mask = false;
     bool padding_mask = false;
@@ -1416,6 +1422,7 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
     std::optional<float> dropout_probability;
     std::optional<float> attn_scale_value;
     std::optional<int> max_seq_len_kv;
+    AttentionScoreModifier_t attention_score_modifier = nullptr;
 
   public:
     enum class input_names {
@@ -1509,6 +1516,12 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
         return *this;
     }
 
+    SDPA_attributes&
+    set_score_mod(AttentionScoreModifier_t fn) {
+        attention_score_modifier = std::move(fn);
+        return *this;
+    }
+
     SDPA_attributes&
     set_sliding_window_length(int const value) {
         sliding_window_length = value;
@@ -1675,6 +1688,10 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
     friend class Attributes<SDPA_backward_attributes>;
     friend class SDPABackwardNode;
     friend class Graph;
+    using Tensor_t = std::shared_ptr<Tensor_attributes>;
+    using Graph_t  = std::shared_ptr<Graph>;
+
+    using AttentionScoreModifier_t = std::function<Tensor_t(Graph_t, Tensor_t)>;
 
     bool alibi_mask = false;
     bool padding_mask = false;
@@ -1688,7 +1705,9 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
     std::optional<int64_t> max_total_seq_len_q;
     std::optional<int64_t> max_total_seq_len_kv;
 
-    bool is_deterministic_algorithm = false;
+    bool is_deterministic_algorithm = false;
+    AttentionScoreModifier_t attention_score_modifier       = nullptr;
+    AttentionScoreModifier_t attention_score_modifier_bprop = nullptr;
 
   public:
     enum class input_names {
@@ -1760,6 +1779,18 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
         return *this;
     }
 
+    SDPA_backward_attributes&
+    set_score_mod(AttentionScoreModifier_t fn) {
+        attention_score_modifier = std::move(fn);
+        return *this;
+    }
+
+    SDPA_backward_attributes&
+    set_score_mod_bprop(AttentionScoreModifier_t fn) {
+        attention_score_modifier_bprop = std::move(fn);
+        return *this;
+    }
+
     SDPA_backward_attributes&
     set_seq_len_q(std::shared_ptr<Tensor_attributes> value) {
         inputs[SDPA_backward_attributes::input_names::SEQ_LEN_Q] = value;
```

include/cudnn_frontend/node/paged_cache_load.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -75,6 +75,12 @@ class PagedCacheLoadNode : public NodeCRTP<PagedCacheLoadNode> {
     error_t
     pre_validate_node() const override final {
         CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating PagedCacheLoadNode " << attributes.name << "...");
+
+        RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 || detail::get_compiled_version() < 90500,
+                                       error_code_t::CUDNN_BACKEND_API_FAILED,
+                                       "The cuDNN backend version must be at least 9.5.0 at compile time and runtime "
+                                       "in order to use PagedCacheLoadNode.");
+
         auto const yOut_dims = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_dim();
         auto const yOut_strides = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_stride();
         auto const container_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::container)->get_dim();
```

include/cudnn_frontend/node/resample.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -169,6 +169,9 @@ class ResampleNode : public NodeCRTP<ResampleNode> {
 
 inline std::array<std::shared_ptr<Tensor_attributes>, 2>
 INode::resample(std::shared_ptr<Tensor_attributes> input, Resample_attributes attributes) {
+    if (attributes.name.empty()) {
+        attributes.name += std::to_string(sub_nodes.size());
+    }
     attributes.inputs[Resample_attributes::input_names::X] = input;
     auto Y = attributes.outputs[Resample_attributes::output_names::Y] = output_tensor(attributes.name + "::Y");
     std::shared_ptr<Tensor_attributes> Index = nullptr;
```
