Commit 747a63a

cuDNN Frontend v1.16.0 is the recommended version for [cuDNN 9.15.0](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-15-0) and later releases.
This release introduces open-source implementations of commonly requested fused kernels for select architectures (Blackwell). These experimental kernels may require additional dependencies such as CuteDSL. The initial release includes:

- [GEMM + Amax](gemm_fusions/gemm_amax.md)
- [GEMM + SwiGLU](gemm_fusions/gemm_swiglu.md)

Additional dependencies can be installed optionally using `pip install nvidia-cudnn-frontend[cutedsl]`. Usage examples and detailed documentation are available in the [test/python/fe_api](test/python/fe_api) directory. Please file an issue for additional kernel requests or bug reports.

- **Block Mask Support**: Starting with cuDNN 9.14.0, SDPA attributes support block masks to exclude tiles that do not require computation. Refer to the [sample implementation](samples/cpp/sdpa/fp16_fwd_with_block_mask.cpp) for usage details.
- **Bug Fix**: Resolved an invalid memory access (IMA) issue in SDPA backward propagation (fixed in cuDNN backend version 9.15.1 and later) that occurred when `s_kv` is not a multiple of 128, the padding mask is disabled, and operations are performed in CUDA graph replay mode.
- **CUDA Graph Compatibility**: Added `BehaviorNote_t::CUDNN_BEHAVIOR_NOTE_CUBLASLT_DEPENDENCY` as a behavior note. This enables filtering of engine configurations (execution plans) that use cuBLAS as a backend, available starting with cuDNN version 9.15.0.
- **Block Scale Quantization**: Added Python bindings for block scale quantize operations ([#173](#173)). Refer to the [sample implementation](test/python/test_block_scale_quantize.py) for usage details.
- **Dependency Optimization**: PyTorch is no longer a required dependency for cuDNN Frontend ([#177](#177)).
- **Tensor Alignment**: Enhanced the tensor descriptor API to accept alignment as an attribute ([#153](#153)).
- **Plan Generation Control**: Updated the `cudnnGetPlan` API to accept an optional maximum plan count parameter, enabling users to limit the number of plans built and autotuned.
- Updated [benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py](benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py) to use correct parameter names and to fix the FLOPS calculations for accurate performance measurements.

Resolved issues:
- [#153](#153) - Tensor descriptor alignment support
- [#173](#173) - Block scale quantize Python bindings
- [#177](#177) - PyTorch dependency removal
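As a quick sanity check of the compatibility statement above, the Python bindings expose the frontend and backend versions through the same accessors the benchmark scripts in this repository print (`cudnn.__version__` and `cudnn.backend_version()`). A minimal sketch, assuming the `nvidia-cudnn-frontend` wheel is installed:

```python
# Minimal version check; the expected values below follow from the release
# notes above (FE 1.16.0 paired with cuDNN backend 9.15.0 or later).
import cudnn

print(f"cuDNN Frontend: {cudnn.__version__}")        # expected: 1.16.0
print(f"cuDNN backend:  {cudnn.backend_version()}")  # expected: >= 91500 for cuDNN 9.15.0
```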
1 parent b849f20 commit 747a63a

86 files changed, +8382 −7126 lines


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.23)

-project(cudnn_frontend VERSION 1.15.0)
+project(cudnn_frontend VERSION 1.16.0)

 option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
 option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)

benchmark/sdpa_benchmark/benchmark_flash_attention.py

Lines changed: 4 additions & 4 deletions
@@ -373,7 +373,7 @@ def time_fwd(func, *args, **kwargs):

     graph_fwd.validate()
     graph_fwd.build_operation_graph()
-    graph_fwd.create_execution_plans([cudnn.heur_mode.A])
+    graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
     graph_fwd.check_support()
     graph_fwd.build_plans()

@@ -416,7 +416,7 @@ def time_fwd(func, *args, **kwargs):
     if headdim != 256:
         graph_bwd.validate()
         graph_bwd.build_operation_graph()
-        graph_bwd.create_execution_plans([cudnn.heur_mode.A])
+        graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
         graph_bwd.check_support()
         graph_bwd.build_plans()

@@ -588,7 +588,7 @@ def time_fwd(func, *args, **kwargs):

     graph_fwd.validate()
     graph_fwd.build_operation_graph()
-    graph_fwd.create_execution_plans([cudnn.heur_mode.A])
+    graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
     graph_fwd.check_support()
     graph_fwd.build_plans()

@@ -671,7 +671,7 @@ def time_fwd(func, *args, **kwargs):
     if headdim == 128:
         graph_bwd.validate()
         graph_bwd.build_operation_graph()
-        graph_bwd.create_execution_plans([cudnn.heur_mode.A])
+        graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
         graph_bwd.check_support()
         graph_bwd.build_plans()


benchmark/sdpa_benchmark_training/README.md

Lines changed: 4 additions & 4 deletions
@@ -142,22 +142,22 @@ As demonstrated can be seen from the results, cuDNN v9 can achieve over 2x the p
 Example commands and outputs:
 ```
 ## For running various PyTorch backends (FlashAttention, cuDNN, ...) or FlashAttention-2:
-$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend pyt_cudnn --data_type bfloat16 --fwd_bwd
+$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend pyt_cudnn --data_type bfloat16 --fwd_bwd
 pyt_cudnn:: Median (fwd, bwd) Execution Times: 24.645 ms (1428 TFLOPS), 78.674 ms (1118 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations)

 ## For directly running cuDNN via cuDNN Frontend
-$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --data_type bfloat16 --fwd_bwd
+$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --data_type bfloat16 --fwd_bwd
 cudnn_fe:: Median (fwd, bwd) Execution Times: 24.543 ms (1434 TFLOPS), 73.210 ms (1201 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations)

 ## For running cuDNN FP8
-$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --data_type fp8 --fwd_bwd
+$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --data_type fp8 --fwd_bwd
 cudnn_fe:: Median (fwd, bwd) Execution Times: 21.334 ms (1649 TFLOPS), 56.373 ms (1560 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations)
 ```

 The cuDNN version used in the benchmark can be replaced by setting the `LD_LIBRARY_PATH` environment variable.
 ```
 $ export LD_LIBRARY_PATH=<my_path_to_cuDNN_9.10.2>
-$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 16384 --kv_seqlen 16384 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --is_causal --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --fwd_bwd --data_type fp8 --verbose
+$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 16384 --kv_seqlen 16384 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --fwd_bwd --data_type fp8 --verbose
 [INFO] cuDNN Backend Version: cudnn.backend_version() = 91002
 [INFO] cuDNN Frontend Version: cudnn.__version__ = '1.12.0'
 [INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.06'

benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py

Lines changed: 53 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import argparse
1010
import torch
1111
from torch.nn.attention import SDPBackend, sdpa_kernel
12+
from torch.nn.attention.bias import causal_lower_right
1213
import os
1314
import numpy as np
1415
import functools
@@ -53,7 +54,6 @@
5354
type=int,
5455
help="Number of iterations to run the layer",
5556
)
56-
parser.add_argument("--is_causal", action="store_true", help="Is causal masking on")
5757
parser.add_argument("--verbose", action="store_true", help="Verbose output")
5858
parser.add_argument(
5959
"--fwd_bwd",
@@ -64,8 +64,8 @@
6464
"--attn_mask",
6565
default="no_mask",
6666
type=str,
67-
help="Attn mask to use. Can be 'padding_causal' or 'no_mask'. If padding_causal, is_causal must be set to false. Only works for cuDNN FE or PyTorch backends.",
68-
choices=["padding_causal", "no_mask"],
67+
help="Attn mask to use. Can be 'top_left', 'bottom_right', or 'no_mask'.",
68+
choices=["top_left", "bottom_right", "no_mask"],
6969
)
7070
parser.add_argument(
7171
"--sdpa_backend",
@@ -111,12 +111,6 @@
111111
f"FP8 is only supported for cudnn_fe and flash_attention_3 backends"
112112
)
113113

114-
if args.attn_mask == "padding_causal":
115-
assert not args.is_causal, "Padding causal attn mask requires is_causal to be false"
116-
assert (
117-
args.q_seqlen <= args.kv_seqlen
118-
), "Padding causal attn mask requires q_seqlen <= kv_seqlen"
119-
120114
# Parse input arguments
121115
num_iters = args.num_iterations
122116
batch_size = args.batch_size
@@ -125,30 +119,16 @@
125119
num_q_heads = args.num_q_heads
126120
num_kv_heads = args.num_kv_heads
127121
head_dim = args.head_dim
128-
is_causal = args.is_causal
129122
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130123
assert device.type == "cuda", "Requires CUDA device"
131-
132124
enable_gqa = num_q_heads != num_kv_heads
125+
assert args.attn_mask != "bottom_right" or q_seqlen <= kv_seqlen, "Bottom right causal mask not supported when q_seqlen > kv_seqlen"
126+
127+
if args.sdpa_backend in ["flash_attention", "flash_attention_3", "pyt_flash_attention"]:
128+
assert args.attn_mask != "top_left", "Flash Attention does not support top left causal mask"
133129

134130
#############################################################
135131
########### Set up SDPA function for each backend ###########
136-
## Define various SDPA functions for each backend
137-
if args.attn_mask == "padding_causal":
138-
# Mask construction: rectangular tensor + triangular tensor
139-
rect_mask = torch.ones(
140-
q_seqlen, (kv_seqlen - q_seqlen), dtype=torch.bool, device=device
141-
)
142-
tri_mask = torch.tril(
143-
torch.ones(q_seqlen, q_seqlen, dtype=torch.bool, device=device)
144-
)
145-
146-
attn_mask = torch.cat(
147-
[rect_mask, tri_mask], dim=1
148-
) # .unsqueeze(0).repeat(batch_size, 1, 1)
149-
padding_fraction = attn_mask.sum() / attn_mask.numel()
150-
else:
151-
padding_fraction = 0.0
152132

153133
## If using cuDNN FE, set up cuDNN graph.
154134
if args.sdpa_backend == "cudnn_fe":
@@ -318,8 +298,8 @@ def convert_to_cudnn_type(torch_type):
         # generate_stats=not is_infer,
         is_inference=is_infer,
         attn_scale=attn_scale,
-        use_causal_mask=is_causal,
-        use_padding_mask=False,
+        diagonal_alignment=cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask =="bottom_right" else cudnn.diagonal_alignment.TOP_LEFT,
+        diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0,
         # dropout=dropout_tuple if is_dropout else None,
     )
 else:

@@ -333,8 +313,8 @@ def convert_to_cudnn_type(torch_type):
         # generate_stats=not is_infer,
         is_inference=is_infer,
         attn_scale=attn_scale,
-        use_causal_mask=is_causal,
-        use_causal_mask_bottom_right=args.attn_mask == "padding_causal",
+        diagonal_alignment=cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask =="bottom_right" else cudnn.diagonal_alignment.TOP_LEFT,
+        diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0,
         dropout=dropout_tuple if is_dropout else None,
     )

@@ -378,7 +358,7 @@ def convert_to_cudnn_type(torch_type):
     ).set_data_type(cudnn.data_type.FLOAT)
     graph_fwd.validate()
     graph_fwd.build_operation_graph()
-    graph_fwd.create_execution_plans([cudnn.heur_mode.A])
+    graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
     graph_fwd.check_support()
     graph_fwd.build_plans()

@@ -452,8 +432,8 @@ def convert_to_cudnn_type(torch_type):
         scale_dV=scale_dV_bwd,
         scale_dP=scale_dP_bwd,
         attn_scale=attn_scale,
-        use_causal_mask=is_causal,
-        use_causal_mask_bottom_right=args.attn_mask == "padding_causal",
+        diagonal_alignment=cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask =="bottom_right" else cudnn.diagonal_alignment.TOP_LEFT,
+        diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0,
         dropout=dropout_tuple if is_dropout else None,
     )
 else:

@@ -471,8 +451,8 @@ def convert_to_cudnn_type(torch_type):
         dO=dO_bwd,
         stats=stats_bwd,
         attn_scale=attn_scale,
-        use_causal_mask=is_causal,
-        use_causal_mask_bottom_right=args.attn_mask == "padding_causal",
+        diagonal_alignment=cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask =="bottom_right" else cudnn.diagonal_alignment.TOP_LEFT,
+        diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0,
         dropout=dropout_tuple if is_dropout else None,
     )

@@ -505,7 +485,7 @@ def convert_to_cudnn_type(torch_type):

     graph_bwd.validate()
     graph_bwd.build_operation_graph()
-    graph_bwd.create_execution_plans([cudnn.heur_mode.A])
+    graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
     graph_bwd.check_support()
     graph_bwd.build_plans()

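The hunks above replace the removed `use_causal_mask`/`use_causal_mask_bottom_right` flags with `diagonal_alignment` plus `diagonal_band_right_bound`, and add `cudnn.heur_mode.FALLBACK` to the execution-plan list. The following is a minimal, self-contained sketch of the same pattern outside the benchmark harness; the tensor shapes, the `pygraph`/`tensor`/`execute` wiring, and the workspace handling are illustrative assumptions based on the cuDNN Frontend Python API rather than code from this commit.

```python
# Minimal SDPA forward graph sketch (assumed environment: CUDA GPU, torch,
# nvidia-cudnn-frontend installed). Shapes are arbitrary illustrative values.
import cudnn
import torch

b, h, s, d = 1, 8, 1024, 128
q_gpu = torch.randn(b, h, s, d, dtype=torch.bfloat16, device="cuda")
k_gpu = torch.randn(b, h, s, d, dtype=torch.bfloat16, device="cuda")
v_gpu = torch.randn(b, h, s, d, dtype=torch.bfloat16, device="cuda")
o_gpu = torch.empty_like(q_gpu)

graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.BFLOAT16,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)
q = graph.tensor(name="q", dim=q_gpu.size(), stride=q_gpu.stride(), data_type=cudnn.data_type.BFLOAT16)
k = graph.tensor(name="k", dim=k_gpu.size(), stride=k_gpu.stride(), data_type=cudnn.data_type.BFLOAT16)
v = graph.tensor(name="v", dim=v_gpu.size(), stride=v_gpu.stride(), data_type=cudnn.data_type.BFLOAT16)

# Bottom-right-aligned causal mask expressed via the new-style attributes;
# diagonal_band_right_bound=None would mean no mask, 0 means a pure causal band.
o, _ = graph.sdpa(
    name="sdpa",
    q=q, k=k, v=v,
    is_inference=True,
    attn_scale=d ** -0.5,
    diagonal_alignment=cudnn.diagonal_alignment.BOTTOM_RIGHT,
    diagonal_band_right_bound=0,
)
o.set_output(True).set_dim(o_gpu.size()).set_stride(o_gpu.stride())

graph.validate()
graph.build_operation_graph()
# Heuristic mode A first, then the generic fallback list, as in the updated benchmarks.
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()

workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")
graph.execute({q: q_gpu, k: k_gpu, v: v_gpu, o: o_gpu}, workspace)
torch.cuda.synchronize()
```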
@@ -620,41 +600,17 @@ def convert_to_cudnn_type(torch_type):
         variant_pack_bwd[offset_bwd] = dropout_offset
 ## Done setting up cuDNN graph.

-
-# Reference implementation for output check
-def pyt_reference_sdpa(query, key, value):
-    if args.attn_mask == "padding_causal":
+# For backends MATH, EFFICIENT_ATTENTION, CUDNN_ATTENTION, PYTORCH_FLASH_ATTENTION
+def pyt_backend_sdpa(query, key, value, backend):
+    with sdpa_kernel(backends=[backend]):
         return torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
             value,
             enable_gqa=enable_gqa,
-            is_causal=is_causal,
-            attn_mask=attn_mask,
+            is_causal=args.attn_mask == "top_left",
+            attn_mask=causal_lower_right(q_seqlen, kv_seqlen) if args.attn_mask == "bottom_right" else None,
         )
-    else:
-        return torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, enable_gqa=enable_gqa, is_causal=is_causal
-        )
-
-
-# For backends MATH, EFFICIENT_ATTENTION, CUDNN_ATTENTION, FLASH_ATTENTION
-def pyt_backend_sdpa(query, key, value, backend):
-    if args.attn_mask == "padding_causal":
-        with sdpa_kernel(backends=[backend]):
-            return torch.nn.functional.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                enable_gqa=enable_gqa,
-                is_causal=is_causal,
-                attn_mask=attn_mask,
-            )
-    else:
-        with sdpa_kernel(backends=[backend]):
-            return torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, enable_gqa=enable_gqa, is_causal=is_causal
-            )


 if args.sdpa_backend == "flash_attention":

@@ -663,15 +619,15 @@ def pyt_backend_sdpa(query, key, value, backend):

     # Flash Attention Native
     def flash_attention_sdpa(query, key, value):
-        return flash_attn_func(query, key, value, causal=is_causal)
+        return flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask")


 if args.sdpa_backend == "flash_attention_3":
     import flash_attn_interface

     def flash_attention_3_sdpa(query, key, value):
         output, _ = flash_attn_interface.flash_attn_func(
-            query, key, value, causal=is_causal
+            query, key, value, causal=args.attn_mask != "no_mask"
         )
         return output

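On the PyTorch side, the reference path above reduces to `scaled_dot_product_attention` with either `is_causal=True` (top-left alignment) or a `causal_lower_right` bias (bottom-right alignment). A small standalone sketch of the two alignments, runnable on CPU via the MATH backend; the shapes and dtype are illustrative and not taken from the benchmark:

```python
# Top-left vs. bottom-right causal alignment, as used by the updated benchmark.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.bias import causal_lower_right

b, h, q_seqlen, kv_seqlen, d = 2, 4, 8, 16, 32
q = torch.randn(b, h, q_seqlen, d)
k = torch.randn(b, h, kv_seqlen, d)
v = torch.randn(b, h, kv_seqlen, d)

with sdpa_kernel(backends=[SDPBackend.MATH]):
    # Top-left alignment: query row i attends to keys 0..i.
    out_top_left = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Bottom-right alignment: the causal diagonal is anchored at the last key,
    # so query row i attends to keys 0..i + (kv_seqlen - q_seqlen).
    out_bottom_right = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=causal_lower_right(q_seqlen, kv_seqlen)
    )

print(out_top_left.shape, out_bottom_right.shape)  # both (2, 4, 8, 32)
```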
@@ -731,24 +687,26 @@ def flops(
     kv_seqlen,
     head_dim,
     num_q_heads,
-    num_kv_heads,
-    causal,
+    attn_mask,
     mode="fwd",
-    padding_fraction=0.0,
 ):
     assert mode in ["fwd", "bwd", "fwd_bwd"]
-    f = (
-        4
-        * batch_size
-        * q_seqlen
-        * kv_seqlen
-        * num_q_heads
-        * head_dim
-        // (2 if causal else 1)
-    )
-    if padding_fraction > 0.0:
-        f = f * padding_fraction
-    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+    if attn_mask == "no_mask":
+        num_nonmasked_elems = q_seqlen * kv_seqlen
+    elif attn_mask == "top_left":
+        num_nonmasked_elems = torch.tril(torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool)).sum()
+    elif attn_mask == "bottom_right":
+        diagonal_offset = kv_seqlen - q_seqlen
+        num_nonmasked_elems = torch.tril(
+            torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool),
+            diagonal=diagonal_offset,
+        ).sum()
+    attention_weights = batch_size * num_q_heads * head_dim * num_nonmasked_elems
+    flops_per_gemm = attention_weights * 2
+
+    result = flops_per_gemm * 2 if mode == "fwd" else (5 * flops_per_gemm if mode == "bwd" else 7 * flops_per_gemm)
+    return result


 def tflops_per_sec(

@@ -757,11 +715,9 @@ def tflops_per_sec(
     kv_seqlen,
     head_dim,
     num_q_heads,
-    num_kv_heads,
-    causal,
+    attn_mask,
     time,
     mode="fwd",
-    padding_fraction=0.0,
 ):
     assert mode in ["fwd", "bwd", "fwd_bwd"]
     f = flops(

@@ -770,10 +726,8 @@ def tflops_per_sec(
         kv_seqlen,
         head_dim,
         num_q_heads,
-        num_kv_heads,
-        causal,
+        attn_mask,
         mode,
-        padding_fraction,
     )
     return f / time / 1e9 if not math.isnan(time) else 0.0  # Assume time is in msec

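In words, the reworked `flops()` above counts two floating-point operations per non-masked score element per head dimension for each GEMM, forms the forward cost from its two GEMMs (QKᵀ and PV), and counts five GEMM-equivalents for the backward pass (seven for forward plus backward). A summary of the model it implements, with B the batch size, H_q the number of query heads, and D the head dimension:

```math
\mathrm{FLOPs}_{\mathrm{GEMM}} = 2\,B\,H_q\,D\,N_{\mathrm{unmasked}},\qquad
N_{\mathrm{unmasked}} =
\begin{cases}
S_q\,S_{kv} & \text{no\_mask}\\[2pt]
\lvert\{(i,j): j \le i\}\rvert & \text{top\_left}\\[2pt]
\lvert\{(i,j): j \le i + (S_{kv} - S_q)\}\rvert & \text{bottom\_right}
\end{cases}
```

```math
\mathrm{FLOPs}_{\mathrm{fwd}} = 2\,\mathrm{FLOPs}_{\mathrm{GEMM}},\qquad
\mathrm{FLOPs}_{\mathrm{bwd}} = 5\,\mathrm{FLOPs}_{\mathrm{GEMM}},\qquad
\mathrm{FLOPs}_{\mathrm{fwd+bwd}} = 7\,\mathrm{FLOPs}_{\mathrm{GEMM}}
```

`tflops_per_sec()` then divides by the median time in milliseconds (FLOPs / ms / 10⁹ = TFLOP/s), matching the `f / time / 1e9` line in the hunk above.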
@@ -1078,7 +1032,14 @@ def tflops_per_sec(
     )
     if args.data_type != "fp8":
         try:
-            output_ref = pyt_reference_sdpa(query, key, value)
+            output_ref = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                enable_gqa=enable_gqa,
+                is_causal=args.attn_mask == "top_left",
+                attn_mask=causal_lower_right(q_seqlen, kv_seqlen) if args.attn_mask == "bottom_right" else None,
+            )
             torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
             forward_diffs.append(
                 torch.max(torch.abs(output.detach() - output_ref.detach())).item()

@@ -1109,11 +1070,9 @@ def tflops_per_sec(
     args.kv_seqlen,
     args.head_dim,
     args.num_q_heads,
-    args.num_kv_heads,
-    args.is_causal,
+    args.attn_mask,
     fwd_median_time,
     "fwd",
-    padding_fraction,
 )
 if args.fwd_bwd:
     bwd_median_time = np.median(np.array(backward_times[5:]))

@@ -1123,11 +1082,9 @@ def tflops_per_sec(
     args.kv_seqlen,
     args.head_dim,
     args.num_q_heads,
-    args.num_kv_heads,
-    args.is_causal,
+    args.attn_mask,
     bwd_median_time,
     "bwd",
-    padding_fraction,
 )
 if args.format_output:
     print(

dlpack_version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.1
+1.1

include/cudnn_frontend/cudnn_interface.h

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ create_cudnn_tensor(
     tensor_builder.setDim(props->get_dim().size(), props->get_dim().data())
         .setStrides(props->get_stride().size(), props->get_stride().data())
         .setId(tensor_uid)
-        .setAlignment(16)
+        .setAlignment(props->get_alignment())
         .setDataType(props->get_data_type())
         .setVirtual(props->get_is_virtual())
         .setByValue(props->get_is_pass_by_value())
