Skip to content

Commit 8801fd7

Browse files
authored
# v1.11.0 release (#136)
## cudnn frontend v1.11 release notes cudnn frontend v1.11 is the preferred cudnn frontend version for cudnn version 9.8.0 and above. With cuDNN frontend v1.11, the minimum supported cudnn version is 9.0.0. ## New API - cudnn frontend v1.11 adds a flexible score modifier to the python SDPA API. Samples showcasing soft capping of the attention scores and an arrow mask are available in the [cudnn_frontend/test/python/test_flexible_sdpa.py](https://github.com/NVIDIA/cuDNN-frontend/blob/main/cudnn_frontend/test/python/test_flexible_sdpa.py) file. A sample usage of the score modifier is shown below: ``` score_mod=partial( custom_mask, mod_tensor=mod_tensor, neg_inf=neg_inf_tensor, seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, ) ``` - The Concatenate operation merges two or more tensors into one, along the specified axis. The user may also specify an in-place merge. ``` std::shared_ptr<Tensor_attributes> concatenate(std::vector<std::shared_ptr<Tensor_attributes>>, Concatenate_attributes); ``` - pip wheels compatible with the windows x86_64 architecture are now available on [pypi](https://pypi.org/project/nvidia-cudnn-frontend/). - The sdpa paged attention API now supports a ragged Q tensor when used with cudnn version 9.7.0 and above. ## Improvements - Users can now pass the CMake flag `-DCMAKE_CXX_FLAGS="-DNV_CUDNN_FRONTEND_DISABLE_LOGGING"` to disable logging in the cuDNN frontend. - Added a new sample to showcase native cudagraph creation from cudnn for the sdpa bprop operation. Fixed a bug when using the update_cuda_graph API to update the cuda graph for the sdpa bprop operation. ## Bug Fixes - Fixed a memory leak in the test harness for some legacy tests that use ragged tensors. - Fixed a bug introduced in the benchmarking script that prevented the sdpa cudnn operation from being executed; the cause was that the `use_padding_mask` attribute had been made mandatory for the sdpa operation, and this has been fixed as well. 
- Updated the paged attention sample so that it does not cause an illegal memory access when the dimensions of the tensors in the sample are changed. - Updated the DgradDReluBNBwdWeight sample to perform the correct operation for the dgrad + drelu fusion.
1 parent 5040925 commit 8801fd7

36 files changed

+2276
-372
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cmake_minimum_required(VERSION 3.17)
22

3-
project(cudnn_frontend VERSION 1.10.0)
3+
project(cudnn_frontend VERSION 1.11.0)
44

55
option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
66
option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)

include/cudnn_frontend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
* The cuDNN Frontend API is a C++ header-only library that demonstrates how to use the cuDNN C backend API. The cuDNN C
3030
* backend API is documented in the cuDNN developer guide.
3131
*
32-
* \section Need Why use Frontend API
32+
* \section Why use Frontend API
3333
*
3434
* Consider the following code snippet which showcases cudnnBackendTensor creation using the backend API and its
3535
* equivalent front-end API code. Many among the backend constructs follow similar pattern.

include/cudnn_frontend/graph_interface.h

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <unordered_map>
4+
#include <string>
45

56
#include "../cudnn_frontend_version.h"
67
#include "node/batchnorm.h"
@@ -25,6 +26,7 @@
2526
#include "node/sdpa_fp8_bwd.h"
2627
#include "node/block_scale_quantize.h"
2728
#include "node/block_scale_dequantize.h"
29+
#include "node/concatenate.h"
2830

2931
#include "backend/backend_descriptor.h"
3032
#include "plans.h"
@@ -293,8 +295,8 @@ class Graph : public ICudnn, public INode {
293295
CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset));
294296

295297
for (auto const &[uid, data] : workspace_modifications) {
296-
(void)uid;
297298
const auto &[operation_type, offset, vec_data] = data;
299+
uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
298300

299301
// 0 means memcpy
300302
if (operation_type == 0) {
@@ -304,7 +306,6 @@ class Graph : public ICudnn, public INode {
304306
vec_data.data(),
305307
vec_data.size() * sizeof(float),
306308
cudaMemcpyHostToDevice));
307-
uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
308309
}
309310
// 1 means memset
310311
else if (operation_type == 1) {
@@ -322,12 +323,19 @@ class Graph : public ICudnn, public INode {
322323

323324
CHECK_CUDA_ERROR(detail::cuda_graph_add_memset_node_set_params(current_node, &params));
324325
}
325-
// Other values do not correspond to cuda APIs
326+
// Other values do not correspond to CUDA graph nodes
327+
else {
328+
continue;
329+
}
326330

327-
CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_root_nodes));
331+
size_t num_dependent_nodes;
332+
CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_dependent_nodes));
328333
RETURN_CUDNN_FRONTEND_ERROR_IF(
329-
num_root_nodes != 1, error_code_t::INVALID_VALUE, "cudnn_cuda_graph should have exactly 1 root node.");
330-
CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, &current_node, &num_root_nodes));
334+
num_dependent_nodes != 1,
335+
error_code_t::INVALID_VALUE,
336+
"Each node of cudnn_cuda_graph before the backend graph node should have exactly 1 dependent node.");
337+
CHECK_CUDA_ERROR(
338+
detail::cuda_graph_node_get_dependent_nodes(current_node, &current_node, &num_dependent_nodes));
331339
}
332340

333341
// Make sure device pointer is provided for all uids expected for this plan
@@ -357,7 +365,10 @@ class Graph : public ICudnn, public INode {
357365
error_code_t::CUDNN_BACKEND_API_FAILED,
358366
"Failed to create variant pack's backend descriptor.");
359367

360-
CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, workspace));
368+
// offset workspace by the already used fe graph workspace
369+
// this is where cudnn backend can start using workspace for its execution plans
370+
void *cudnn_workspace = static_cast<char *>(workspace) + fe_workspace_size;
371+
CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, cudnn_workspace));
361372

362373
int64_t candidate = plans.candidate;
363374
CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(candidate));
@@ -367,8 +378,9 @@ class Graph : public ICudnn, public INode {
367378
backend_cuda_graph));
368379

369380
// There should be nothing after the backend graph
370-
CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_root_nodes));
371-
RETURN_CUDNN_FRONTEND_ERROR_IF(num_root_nodes != 0,
381+
size_t num_dependent_nodes;
382+
CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_dependent_nodes));
383+
RETURN_CUDNN_FRONTEND_ERROR_IF(num_dependent_nodes != 0,
372384
error_code_t::INVALID_VALUE,
373385
"cudnn_cuda_graph should have no graph nodes after the backend graph node.");
374386

@@ -431,8 +443,8 @@ class Graph : public ICudnn, public INode {
431443
CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset));
432444

433445
for (auto const &[uid, data] : workspace_modifications) {
434-
(void)uid;
435446
const auto &[operation_type, offset, vec_data] = data;
447+
uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
436448

437449
cudaGraphNode_t node = nullptr;
438450

@@ -446,7 +458,6 @@ class Graph : public ICudnn, public INode {
446458
vec_data.data(),
447459
vec_data.size() * sizeof(float),
448460
cudaMemcpyHostToDevice));
449-
uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
450461
}
451462
// 1 means memset
452463
else if (operation_type == 1) {
@@ -465,7 +476,10 @@ class Graph : public ICudnn, public INode {
465476
CHECK_CUDA_ERROR(detail::cuda_graph_add_memset_node(
466477
&node, cudnn_cuda_graph, &last_node, last_node != nullptr, &params));
467478
}
468-
// Other values do not correspond to cuda APIs
479+
// Other values do not correspond to CUDA graph nodes
480+
else {
481+
continue;
482+
}
469483

470484
last_node = node;
471485
}
@@ -495,7 +509,11 @@ class Graph : public ICudnn, public INode {
495509
RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS,
496510
error_code_t::CUDNN_BACKEND_API_FAILED,
497511
"Failed to create variant pack's backend descriptor.");
498-
CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, workspace));
512+
513+
// offset workspace by the already used fe graph workspace
514+
// this is where cudnn backend can start using workspace for its execution plans
515+
void *cudnn_workspace = static_cast<char *>(workspace) + fe_workspace_size;
516+
CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, cudnn_workspace));
499517

500518
// Get the plan candidate. It only makes to sense to make cuda graph after execution plan has been built.
501519
// And in that case the candidate would have been set.
@@ -1019,6 +1037,9 @@ class Graph : public ICudnn, public INode {
10191037
std::shared_ptr<Tensor_attributes>,
10201038
Block_scale_dequantize_attributes);
10211039

1040+
std::shared_ptr<Tensor_attributes> concatenate(std::vector<std::shared_ptr<Tensor_attributes>>,
1041+
Concatenate_attributes);
1042+
10221043
[[deprecated]] std::array<std::shared_ptr<Tensor_attributes>, 2>
10231044
scaled_dot_product_flash_attention(std::shared_ptr<Tensor_attributes> q,
10241045
std::shared_ptr<Tensor_attributes> k,
@@ -1168,11 +1189,23 @@ class Graph : public ICudnn, public INode {
11681189
short_node["inputs"] = {};
11691190
short_node["outputs"] = {};
11701191

1192+
auto node_name = sub_node["tag"].get<std::string>();
1193+
auto i = 0;
11711194
// Process node inputs
11721195
for (const auto &input : sub_node["inputs"]) {
1173-
// Extract port_name and tensor_name
1174-
auto port_name = input[0].get<std::string>();
1175-
auto tensor_info = input[1];
1196+
std::string port_name;
1197+
json tensor_info;
1198+
1199+
if (node_name == "CONCATENATE") {
1200+
// Extract port_name and tensor_name
1201+
port_name = std::to_string(i);
1202+
tensor_info = input;
1203+
i++;
1204+
} else {
1205+
// Extract port_name and tensor_name
1206+
port_name = input[0].get<std::string>();
1207+
tensor_info = input[1];
1208+
}
11761209

11771210
if (tensor_info.is_null()) {
11781211
continue;
@@ -2161,10 +2194,29 @@ Graph::block_scale_dequantize(std::shared_ptr<Tensor_attributes> x,
21612194
return Y;
21622195
}
21632196

2197+
inline std::shared_ptr<Tensor_attributes>
2198+
Graph::concatenate(std::vector<std::shared_ptr<Tensor_attributes>> x, Concatenate_attributes attributes) {
2199+
if (attributes.name.empty()) {
2200+
attributes.name += std::to_string(sub_nodes.size());
2201+
}
2202+
2203+
// Set outputs
2204+
auto Y = attributes.outputs[Concatenate_attributes::output_names::Y] = output_tensor(attributes.name + "::Y");
2205+
2206+
// Set inputs
2207+
for (auto &element : x) {
2208+
attributes.inputs.push_back(element);
2209+
}
2210+
2211+
sub_nodes.emplace_back(std::make_unique<ConcatenateNode>(std::move(attributes), context));
2212+
2213+
return Y;
2214+
}
2215+
21642216
static inline std::ostream &
21652217
operator<<(std::ostream &os, Graph const &graph) {
21662218
os << graph.print();
21672219
return os;
21682220
}
21692221

2170-
} // namespace cudnn_frontend::graph
2222+
} // namespace cudnn_frontend::graph

0 commit comments

Comments
 (0)