
Commit 666996f

# v1.12.0 release (#141)
## cudnn frontend v1.12 release notes

cudnn frontend v1.12 is the preferred cudnn frontend version for cudnn version 9.9.0 and above. cudnn frontend v1.12 is also the minimum cudnn frontend version required to work with cuda 13.0 and above.

This release updates the dlpack version and raises the cmake minimum required version to 3.18.

## New API

- Allows compilation and loading of cudnn frontend with cudnn-jit packages.
- Introduces the Adaptive Layernorm (fprop and bprop) operation in cudnn:

  ```
  std::array<std::shared_ptr<Tensor_attributes>, 3> adalayernorm(std::shared_ptr<Tensor_attributes>& input,
                                                                 std::shared_ptr<Tensor_attributes>& scale,
                                                                 std::shared_ptr<Tensor_attributes>& bias,
                                                                 AdaLayernorm_attributes attributes);

  std::array<std::shared_ptr<Tensor_attributes>, 3> adalayernorm_backward(std::shared_ptr<Tensor_attributes> dy,
                                                                          std::shared_ptr<Tensor_attributes> x,
                                                                          std::shared_ptr<Tensor_attributes> scale,
                                                                          AdaLayernorm_backward_attributes options);
  ```

  Please refer to the [samples](samples/cpp/norm/adaptive_layernorm.cpp) for usage; a hedged build sketch also follows these notes.

- The cudnn frontend python API introduces two decorator functions, `cudnn.jit` and `cudnn.graph`, for simpler graph creation in python. Refer to the [matmul sample](samples/python/01_matmul_bias.ipynb) for usage.

## Improvements

### SDPA

- Allows large embedded dimension (d > 128) for fprop across Ampere, Hopper, and Blackwell architectures for bf16/fp16.
- Added better validation checks for sliding window attention for cudnn version 9.9.0 and below.
- Sliding window attention now supports cases where s_q > s_kv.
- The sdpa_fp8 operation now pads the masking operation with negative infinity rather than a large negative value. This improves the numerical stability of the sdpa operation with the fp8 data type.
- Paged attention now supports page tables in a packed format.

### Normalizations

- Allow zero-centered scale in layer norm. Refer to this [sample](samples/cpp/norm/norm_zero_centered_gamma.cpp) for usage.

### Others

- cudnn frontend now supports serialization of the dynamic kernel cache.

## Bug Fixes

- Fixed the dlopen of cudart.so to look for the binary with the version name.
- Correctly fail when SDPA bprop is called on Blackwell with embedded dimension (d) > 128.
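To make the new entry point concrete, here is a minimal sketch of building an adaptive layernorm forward graph with the C++ API, patterned after the existing layernorm samples. The dimensions, tensor properties, and the `fe` namespace alias are illustrative assumptions; `samples/cpp/norm/adaptive_layernorm.cpp` remains the authoritative reference.

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch: forward (TRAINING) adaptive layernorm graph.
// Shapes are illustrative: batch 4, sequence 1024, hidden 128.
fe::graph::Graph
build_adalayernorm_fprop() {
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);

    auto X     = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("X")
                                  .set_dim({4, 1024, 128})
                                  .set_stride({1024 * 128, 128, 1}));
    auto scale = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("scale")
                                  .set_dim({4, 1, 128})
                                  .set_stride({128, 128, 1}));
    auto bias  = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("bias")
                                  .set_dim({4, 1, 128})
                                  .set_stride({128, 128, 1}));

    // Epsilon is passed by value as a scalar tensor, as in the layernorm samples.
    auto epsilon = graph.tensor(fe::graph::Tensor_attributes()
                                    .set_name("epsilon")
                                    .set_dim({1, 1, 1})
                                    .set_stride({1, 1, 1})
                                    .set_data_type(fe::DataType_t::FLOAT)
                                    .set_is_pass_by_value(true));

    auto attributes = fe::graph::AdaLayernorm_attributes()
                          .set_name("adalayernorm")
                          .set_forward_phase(fe::NormFwdPhase_t::TRAINING)
                          .set_epsilon(epsilon);

    // Returns {Y, MEAN, INV_VARIANCE}; the statistics are produced only in the TRAINING phase.
    auto [Y, mean, inv_variance] = graph.adalayernorm(X, scale, bias, attributes);
    Y->set_output(true);
    mean->set_output(true);
    inv_variance->set_output(true);
    return graph;
}
```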
1 parent 8801fd7 commit 666996f

File tree: 69 files changed (+15810 / −9987 lines)


CMakeLists.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
-cmake_minimum_required(VERSION 3.17)
+cmake_minimum_required(VERSION 3.18)
 
-project(cudnn_frontend VERSION 1.11.0)
+project(cudnn_frontend VERSION 1.12.0)
 
 option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
 option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
```

cmake/cuDNN.cmake

Lines changed: 15 additions & 11 deletions
```diff
@@ -12,11 +12,17 @@ string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header
 string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
 
 function(find_cudnn_library NAME)
+    if(NOT "${ARGV1}" STREQUAL "OPTIONAL")
+        set(_cudnn_required "REQUIRED")
+    else()
+        set(_cudnn_required "")
+    endif()
+
     find_library(
         ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
         HINTS $ENV{CUDNN_LIBRARY_PATH} ${CUDNN_LIBRARY_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR}
         PATH_SUFFIXES lib64 lib/x64 lib
-        REQUIRED
+        ${_cudnn_required}
     )
 
     if(${NAME}_LIBRARY)
@@ -30,8 +36,6 @@ function(find_cudnn_library NAME)
     else()
         message(STATUS "${NAME} not found.")
     endif()
-
-
 endfunction()
 
 find_cudnn_library(cudnn)
@@ -87,22 +91,22 @@ if(CUDNN_MAJOR_VERSION EQUAL 8)
         CUDNN::cudnn_ops_infer
     )
 elseif(CUDNN_MAJOR_VERSION EQUAL 9)
-    find_cudnn_library(cudnn_cnn)
-    find_cudnn_library(cudnn_adv)
     find_cudnn_library(cudnn_graph)
-    find_cudnn_library(cudnn_ops)
     find_cudnn_library(cudnn_engines_runtime_compiled)
-    find_cudnn_library(cudnn_engines_precompiled)
-    find_cudnn_library(cudnn_heuristic)
+    find_cudnn_library(cudnn_ops OPTIONAL)
+    find_cudnn_library(cudnn_cnn OPTIONAL)
+    find_cudnn_library(cudnn_adv OPTIONAL)
+    find_cudnn_library(cudnn_engines_precompiled OPTIONAL)
+    find_cudnn_library(cudnn_heuristic OPTIONAL)
 
     target_link_libraries(
         CUDNN::cudnn_all
         INTERFACE
-        CUDNN::cudnn_adv
-        CUDNN::cudnn_ops
-        CUDNN::cudnn_cnn
         CUDNN::cudnn_graph
         CUDNN::cudnn_engines_runtime_compiled
+        CUDNN::cudnn_ops
+        CUDNN::cudnn_cnn
+        CUDNN::cudnn_adv
         CUDNN::cudnn_engines_precompiled
         CUDNN::cudnn_heuristic
     )
```

include/cudnn_frontend/backend/kernel_cache.h

Lines changed: 69 additions & 6 deletions
```diff
@@ -68,10 +68,68 @@ class KernelCache : public detail::backend_descriptor {
             return {error_code_t::CUDNN_BACKEND_API_FAILED,
                     "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR: Check CUDNN_VERSION >= 9.4"};
         }
-        return {error_code_t::OK, ""};
+        return {};
+    }
+
+    error_t
+    to_json(std::string &str_json) const {
+        str_json.clear();
+#if (CUDNN_VERSION >= 91000)
+        RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 91000,
+                                       error_code_t::CUDNN_BACKEND_API_FAILED,
+                                       "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10.");
+
+        int64_t serializationSize;
+        std::vector<char> serialization_buf;
+        CHECK_CUDNN_ERROR(detail::get_attribute(
+            get_ptr(), CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION, CUDNN_TYPE_CHAR, 0, &serializationSize, nullptr));
+        serialization_buf.resize(static_cast<size_t>(serializationSize));
+
+        CHECK_CUDNN_ERROR(detail::get_attribute(get_ptr(),
+                                                CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION,
+                                                CUDNN_TYPE_CHAR,
+                                                serializationSize,
+                                                &serializationSize,
+                                                serialization_buf.data()));
+        std::string json_string(serialization_buf.begin(), serialization_buf.end());
+        str_json = json_string;
+        return {};
+#else
+        (void)str_json;
+        return {error_code_t::CUDNN_BACKEND_API_FAILED,
+                "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10."};
+#endif
+    }
+
+    error_t
+    from_json(const std::string &json_cache) {
+#if (CUDNN_VERSION >= 91000)
+        RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 91000,
+                                       error_code_t::CUDNN_BACKEND_API_FAILED,
+                                       "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10.");
+
+        // Check if the kernel cache is already initialized
+        RETURN_CUDNN_FRONTEND_ERROR_IF(
+            get_ptr() != nullptr, error_code_t::CUDNN_BACKEND_API_FAILED, "Kernel cache is already initialized.");
+
+        // Initialize the kernel cache descriptor
+        CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR));
+
+        std::vector<char> serialization_buf;
+        serialization_buf.assign(json_cache.begin(), json_cache.end());
+        CHECK_CUDNN_ERROR(detail::set_attribute(get_ptr(),
+                                                CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION,
+                                                CUDNN_TYPE_CHAR,
+                                                serialization_buf.size(),
+                                                serialization_buf.data()));
+        return {};
+#else
+        (void)json_cache;
+        return {error_code_t::CUDNN_BACKEND_API_FAILED,
+                "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10."};
+#endif
     }
 
-   private:
     // Responsible for initializing, setting operation graph attribute, and finalizing kernel cache
     // Check for both compile-time and runtime cuDNN version
     error_t
@@ -80,26 +138,31 @@ class KernelCache : public detail::backend_descriptor {
         RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90400,
                                        error_code_t::GRAPH_NOT_SUPPORTED,
                                        "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR is only available starting 9.4.");
-        CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR));
+        if (get_ptr() == nullptr) {
+            CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR));
+        }
 #if (CUDNN_VERSION >= 90500)
         RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500,
                                        error_code_t::GRAPH_NOT_SUPPORTED,
                                        "CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH is only available starting 9.5.");
-        CHECK_CUDNN_ERROR(detail::set_attribute(
-            get_ptr(), CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph));
+        if (op_graph) {
+            CHECK_CUDNN_ERROR(detail::set_attribute(
+                get_ptr(), CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph));
+        }
 #else
         (void)op_graph;
 #endif
         CHECK_CUDNN_FRONTEND_ERROR(finalize());
         finalized = true;
-        return {error_code_t::OK, ""};
+        return {};
 #else
         (void)op_graph;
         return {error_code_t::CUDNN_BACKEND_API_FAILED,
                 "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR is only available starting 9.4."};
 #endif
     }
 
+   private:
     bool finalized = false;
 };
 } // namespace cudnn_frontend
```
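The `to_json`/`from_json` methods above back the "serialization of the dynamic kernel cache" item in the release notes. Below is a sketch of one possible save/load round trip; the file helpers are hypothetical, and handing the restored cache to a graph (for example via `set_kernel_cache`) follows the existing dynamic-shape workflow rather than anything introduced in this commit.

```cpp
#include <cudnn_frontend.h>

#include <fstream>
#include <iterator>
#include <memory>
#include <string>

namespace fe = cudnn_frontend;

// Sketch: persist a kernel cache to disk. Requires cuDNN >= 9.10 at build and run time.
bool
save_kernel_cache(fe::KernelCache const& cache, std::string const& path) {
    std::string serialized;
    if (!cache.to_json(serialized).is_good()) {  // serialize the cache into a JSON string
        return false;
    }
    std::ofstream out(path, std::ios::binary);
    out << serialized;
    return static_cast<bool>(out);
}

// Sketch: restore a kernel cache from disk; returns nullptr on failure.
std::shared_ptr<fe::KernelCache>
load_kernel_cache(std::string const& path) {
    std::ifstream in(path, std::ios::binary);
    std::string serialized((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());

    auto cache = std::make_shared<fe::KernelCache>();
    if (!cache->from_json(serialized).is_good()) {  // rebuilds the backend descriptor from the JSON blob
        return nullptr;
    }
    return cache;  // attach to a graph before building plans, e.g. graph.set_kernel_cache(cache)
}
```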

include/cudnn_frontend/graph_interface.h

Lines changed: 58 additions & 3 deletions
```diff
@@ -15,6 +15,7 @@
 #include "node/dbn_weight.h"
 #include "node/genstats.h"
 #include "node/layernorm.h"
+#include "node/adaptive_layernorm.h"
 #include "node/instancenorm.h"
 #include "node/rmsnorm.h"
 #include "node/resample.h"
@@ -557,7 +558,6 @@ class Graph : public ICudnn, public INode {
 
         // Validate the nodes, which in turn also infers missing tensor attributes.
         CHECK_CUDNN_FRONTEND_ERROR(validate_subtree());
-
         // Validate all outputs, which should now have everything set to be lowered to backend.
         for (auto const &output : full_graph_outputs) {
             CHECK_CUDNN_FRONTEND_ERROR(output->validate());
@@ -914,6 +914,11 @@ class Graph : public ICudnn, public INode {
                                                                 std::shared_ptr<Tensor_attributes>,
                                                                 Layernorm_attributes);
 
+    std::array<std::shared_ptr<Tensor_attributes>, 3> adalayernorm(std::shared_ptr<Tensor_attributes>,
+                                                                   std::shared_ptr<Tensor_attributes>,
+                                                                   std::shared_ptr<Tensor_attributes>,
+                                                                   AdaLayernorm_attributes);
+
     std::array<std::shared_ptr<Tensor_attributes>, 3> instancenorm(std::shared_ptr<Tensor_attributes>,
                                                                    std::shared_ptr<Tensor_attributes>,
                                                                    std::shared_ptr<Tensor_attributes>,
@@ -968,6 +973,11 @@ class Graph : public ICudnn, public INode {
                                                                          std::shared_ptr<Tensor_attributes>,
                                                                          Layernorm_backward_attributes);
 
+    std::array<std::shared_ptr<Tensor_attributes>, 3> adalayernorm_backward(std::shared_ptr<Tensor_attributes>,
+                                                                            std::shared_ptr<Tensor_attributes>,
+                                                                            std::shared_ptr<Tensor_attributes>,
+                                                                            AdaLayernorm_backward_attributes);
+
     std::array<std::shared_ptr<Tensor_attributes>, 3> instancenorm_backward(std::shared_ptr<Tensor_attributes>,
                                                                             std::shared_ptr<Tensor_attributes>,
                                                                             std::shared_ptr<Tensor_attributes>,
@@ -1182,7 +1192,6 @@ class Graph : public ICudnn, public INode {
         j["nodes"];
         j["tensors"];
         std::unordered_set<std::string> tensors;
-
         for (const auto &sub_node : full_json["nodes"]) {
             // Create a short version of the node
             auto short_node = sub_node;
@@ -1212,7 +1221,6 @@ class Graph : public ICudnn, public INode {
             }
 
             std::string tensor_name = tensor_info["name"].get<std::string>();
-
             // Update short_node inputs
             short_node["inputs"][port_name] = tensor_name;
 
@@ -1699,6 +1707,31 @@ Graph::layernorm(std::shared_ptr<Tensor_attributes> x,
     return {Y, MEAN, INV_VARIANCE};
 }
 
+inline std::array<std::shared_ptr<Tensor_attributes>, 3>
+Graph::adalayernorm(std::shared_ptr<Tensor_attributes> x,
+                    std::shared_ptr<Tensor_attributes> scale,
+                    std::shared_ptr<Tensor_attributes> bias,
+                    AdaLayernorm_attributes attributes) {
+    // Set outputs
+    auto Y = attributes.outputs[AdaLayernorm_attributes::output_names::Y] = output_tensor(attributes.name + "::Y");
+    std::shared_ptr<Tensor_attributes> MEAN         = nullptr;
+    std::shared_ptr<Tensor_attributes> INV_VARIANCE = nullptr;
+    if (attributes.forward_phase == NormFwdPhase_t::TRAINING) {
+        MEAN = attributes.outputs[AdaLayernorm_attributes::output_names::MEAN] =
+            output_tensor(attributes.name + "::MEAN");
+        INV_VARIANCE = attributes.outputs[AdaLayernorm_attributes::output_names::INV_VARIANCE] =
+            output_tensor(attributes.name + "::INV_VARIANCE");
+    }
+    // Set inputs
+    attributes.inputs[AdaLayernorm_attributes::input_names::X]     = x;
+    attributes.inputs[AdaLayernorm_attributes::input_names::SCALE] = scale;
+    attributes.inputs[AdaLayernorm_attributes::input_names::BIAS]  = bias;
+
+    sub_nodes.emplace_back(std::make_unique<AdaLayerNormNode>(std::move(attributes), context));
+
+    return {Y, MEAN, INV_VARIANCE};
+}
+
 inline std::array<std::shared_ptr<Tensor_attributes>, 3>
 Graph::instancenorm(std::shared_ptr<Tensor_attributes> x,
                     std::shared_ptr<Tensor_attributes> scale,
@@ -1848,6 +1881,28 @@ Graph::layernorm_backward(std::shared_ptr<Tensor_attributes> dy,
     return {DX, DSCALE, DBIAS};
 }
 
+inline std::array<std::shared_ptr<Tensor_attributes>, 3>
+Graph::adalayernorm_backward(std::shared_ptr<Tensor_attributes> dy,
+                             std::shared_ptr<Tensor_attributes> x,
+                             std::shared_ptr<Tensor_attributes> scale,
+                             AdaLayernorm_backward_attributes attributes) {
+    // Set outputs
+    auto DX = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DX] =
+        output_tensor(attributes.name + "::DX");
+    auto DSCALE = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DSCALE] =
+        output_tensor(attributes.name + "::DSCALE");
+    auto DBIAS = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DBIAS] =
+        output_tensor(attributes.name + "::DBIAS");
+    // Set inputs
+    attributes.inputs[AdaLayernorm_backward_attributes::input_names::DY]    = dy;
+    attributes.inputs[AdaLayernorm_backward_attributes::input_names::X]     = x;
+    attributes.inputs[AdaLayernorm_backward_attributes::input_names::SCALE] = scale;
+
+    sub_nodes.emplace_back(std::make_unique<DAdaLayerNormNode>(std::move(attributes), context));
+
+    return {DX, DSCALE, DBIAS};
+}
+
 inline std::shared_ptr<Tensor_attributes>
 Graph::conv_fprop(std::shared_ptr<Tensor_attributes> x,
                   std::shared_ptr<Tensor_attributes> w,
```
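Once the adalayernorm node has been added to a graph (as in the sketch after the release notes above), the rest of the flow is the standard frontend validate/build/execute sequence, which this commit does not change. The wrapper below is only an illustrative recap; its name and parameters are assumptions.

```cpp
#include <cudnn_frontend.h>

#include <memory>
#include <unordered_map>

namespace fe = cudnn_frontend;

// Sketch: the usual finishing steps for any frontend graph, including one with adalayernorm.
fe::error_t
finalize_and_run(fe::graph::Graph& graph,
                 cudnnHandle_t handle,
                 std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>& variant_pack,
                 void* workspace) {
    auto status = graph.validate();  // infers missing tensor attributes
    if (!status.is_good()) return status;

    status = graph.build_operation_graph(handle);  // lower to backend descriptors
    if (!status.is_good()) return status;

    status = graph.create_execution_plans({fe::HeurMode_t::A});  // heuristics-driven plan candidates
    if (!status.is_good()) return status;

    status = graph.check_support(handle);
    if (!status.is_good()) return status;

    status = graph.build_plans(handle);
    if (!status.is_good()) return status;

    // variant_pack maps each graph tensor (X, scale, bias, epsilon, Y, ...) to its device pointer.
    return graph.execute(handle, variant_pack, workspace);
}
```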

include/cudnn_frontend/graph_properties.h

Lines changed: 48 additions & 0 deletions
```diff
@@ -943,6 +943,54 @@ class Layernorm_attributes : public Attributes<Layernorm_attributes> {
     }
 };
 
+class AdaLayernorm_attributes : public Attributes<AdaLayernorm_attributes> {
+    friend class Attributes<AdaLayernorm_attributes>;
+    friend class AdaLayerNormNode;
+    friend class Graph;
+
+    NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET;
+
+   public:
+    enum class input_names { X, SCALE, BIAS, EPSILON };
+    std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
+    enum class output_names { Y, MEAN, INV_VARIANCE };
+    std::unordered_map<output_names, std::shared_ptr<Tensor_attributes>> outputs;
+    NLOHMANN_DEFINE_TYPE_INTRUSIVE(AdaLayernorm_attributes, name, compute_data_type, inputs, outputs, forward_phase)
+
+    AdaLayernorm_attributes&
+    set_forward_phase(NormFwdPhase_t const value) {
+        forward_phase = value;
+        return *this;
+    }
+
+    AdaLayernorm_attributes&
+    set_epsilon(std::shared_ptr<Tensor_attributes>& value) {
+        inputs[AdaLayernorm_attributes::input_names::EPSILON] = value;
+        return *this;
+    }
+};
+
+class AdaLayernorm_backward_attributes : public Attributes<AdaLayernorm_backward_attributes> {
+    friend class Attributes<AdaLayernorm_backward_attributes>;
+    friend class DAdaLayerNormNode;
+    friend class Graph;
+
+   public:
+    enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE, EPSILON };
+    std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
+    enum class output_names { DX, DSCALE, DBIAS };
+    std::unordered_map<output_names, std::shared_ptr<Tensor_attributes>> outputs;
+    NLOHMANN_DEFINE_TYPE_INTRUSIVE(AdaLayernorm_backward_attributes, name, compute_data_type, inputs, outputs)
+
+    AdaLayernorm_backward_attributes&
+    set_saved_mean_and_inv_variance(std::shared_ptr<Tensor_attributes> mean,
+                                    std::shared_ptr<Tensor_attributes> inv_variance) {
+        inputs[AdaLayernorm_backward_attributes::input_names::MEAN]         = mean;
+        inputs[AdaLayernorm_backward_attributes::input_names::INV_VARIANCE] = inv_variance;
+        return *this;
+    }
+};
+
 class Instancenorm_attributes : public Attributes<Instancenorm_attributes> {
     friend class Attributes<Instancenorm_attributes>;
     friend class InstanceNormNode;
```
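For the backward direction, the attribute classes above combine with `Graph::adalayernorm_backward` as sketched below. The helper name is hypothetical, and it assumes the tensor handles were created on the same graph, with MEAN and INV_VARIANCE saved by a TRAINING forward pass.

```cpp
#include <cudnn_frontend.h>

#include <array>
#include <memory>

namespace fe = cudnn_frontend;

// Sketch: request AdaLayernorm gradients, reusing the saved forward statistics.
std::array<std::shared_ptr<fe::graph::Tensor_attributes>, 3>
build_adalayernorm_bprop(fe::graph::Graph& graph,
                         std::shared_ptr<fe::graph::Tensor_attributes> dy,
                         std::shared_ptr<fe::graph::Tensor_attributes> x,
                         std::shared_ptr<fe::graph::Tensor_attributes> scale,
                         std::shared_ptr<fe::graph::Tensor_attributes> mean,
                         std::shared_ptr<fe::graph::Tensor_attributes> inv_variance) {
    auto attributes = fe::graph::AdaLayernorm_backward_attributes()
                          .set_name("adalayernorm_bwd")
                          // Reuse the statistics produced by the TRAINING forward pass.
                          .set_saved_mean_and_inv_variance(mean, inv_variance);

    auto [dx, dscale, dbias] = graph.adalayernorm_backward(dy, x, scale, attributes);
    dx->set_output(true);      // DX
    dscale->set_output(true);  // DSCALE
    dbias->set_output(true);   // DBIAS
    return {dx, dscale, dbias};
}
```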
