[PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1+#2693

Open
KshitijLakhani wants to merge 15 commits into NVIDIA:main from KshitijLakhani:klakhani/maint/sm120-thd-flash-support

Conversation


@KshitijLakhani (Collaborator) commented Feb 20, 2026

Description

Enable sm120 support for the THD layout in fused attention for cuDNN 9.18.1+.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • For sm120, change the shape of the stats tensors to BHS1 instead of TH1, and propagate these changes throughout.
  • For sm120, disable fused attention in get_attention_backends() when T3HD or TH3D shapes are used, as cuDNN does not support them. Also assert the same in common before calling the f16 arbitrary-seqlen fwd/bwd.
  • For sm120, disable fused and flash attention for the KV cache in get_attention_backends() until fully supported (a sketch of this gating follows the list).
  • NOTE: No changes were made to test code (e.g., skips for sm120). Any skips are achieved by disabling the backend attention type rather than by the heavy-handed approach of disabling tests.
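
Below is a minimal, hypothetical sketch of the gating described above, written as standalone Python. The function name and the tuple-based version checks are illustrative only and do not mirror TE's actual get_attention_backends() signature:

# Hypothetical sketch of the sm120 gating described in the Changes list.
# Names and version encodings are illustrative, not TE's real identifiers.
def allow_fused_attn_thd(compute_capability, cudnn_version, qkv_layout, is_kv_cache):
    """Return True if the fused-attention backend may serve a THD request."""
    if compute_capability == (12, 0):  # sm120
        if cudnn_version < (9, 18, 1):
            return False  # THD fused attention on sm120 needs cuDNN 9.18.1+
        if qkv_layout in ("t3hd", "th3d"):
            return False  # cuDNN does not support these layouts on sm120
        if is_kv_cache:
            return False  # KV-cache paths disabled on sm120 until fully supported
    return True

# Example: sm120 with cuDNN 9.18.1 and a plain THD layout is allowed.
assert allow_fused_attn_thd((12, 0), (9, 18, 1), "thd_thd_thd", False)
assert not allow_fused_attn_thd((12, 0), (9, 18, 1), "t3hd", False)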

Test results:

Ran the PyT attention tests on sm120 with no failures:

klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_attention_with_cp.py 
===================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 8488 items                                                                                                                                                                                                           
Running 8488 items in this shard

tests/pytorch/attention/test_attention_with_cp.py .s.ss.s.ss...sssssss...ss...ss.s.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  1%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  4%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  7%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  9%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 12%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 14%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 17%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 19%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 22%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 24%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 27%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 29%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 32%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 35%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 37%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 40%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 42%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 45%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 47%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 50%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 52%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 55%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 57%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 60%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 63%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.s.sss.s.s.s.s.s.sss.s.sssss.sssssssssssss.s.sss.s.sssssssssssssssss [ 65%]
ssssssssssssssssss.s.sss.s.sssssssssss.s.s.sss.s.sssssssssss.s.s.sssss.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 68%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 70%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 73%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 75%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 78%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 80%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 83%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 85%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 88%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 91%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 98%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                       [100%]

======================================================================================================= warnings summary =======================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=causal
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================================================== 44 passed, 8444 skipped, 3 warnings in 576.35s (0:09:36) ===================================================================================
klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_attention.py 
===================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 2607 items                                                                                                                                                                                                           
Running 2607 items in this shard

tests/pytorch/attention/test_attention.py ......................................................................ss............................................................................s.....s........ssss....... [  6%]
.ssss.s.................................................ss..sss...ss..sss....s...s.....s...s.....s...s....ss..sss...ss..sss....s...s.....s...s.....s...s...ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.......................... [ 14%]
................................ss..................ss..............ssssss..ssss....sssss.s....sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 23%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 31%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 39%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 48%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 56%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 64%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 81%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 89%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                                                                                [100%]

======================================================================================================= warnings summary =======================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=causal
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=padding_causal
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=causal_bottom_right
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=padding_causal_bottom_right
    warnings.warn(

tests/pytorch/attention/test_attention.py::test_dot_product_attention[False-False-None-True-False-base_1_0-model_configs0-dtype0]
  /usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py:869: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:335.)
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

tests/pytorch/attention/test_attention.py: 14 warnings
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================== 396 passed, 2211 skipped, 21 warnings in 134.70s (0:02:14) ==================================================================================
klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_cp_utils.py 
===================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 9 items                                                                                                                                                                                                                                 
Running 9 items in this shard

tests/pytorch/attention/test_cp_utils.py .........                                                                                                                                                                                          [100%]

================================================================================================================ warnings summary =================================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================== 9 passed, 2 warnings in 3.93s ==========================================================================================================
klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_kv_cache.py 
========================================================================================================================== test session starts ===========================================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 576 items                                                                                                                                                                                                                                                      
Running 576 items in this shard

tests/pytorch/attention/test_kv_cache.py ssssssssssssssssssssssssssssssssssssssssssssssss........................ssssssssssssssssssssssssssssssssssssssssssssssss........................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 37%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 82%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                                                                              [100%]

============================================================================================================================ warnings summary ============================================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/pytorch/attention/test_kv_cache.py: 576 warnings
  /home/klakhani/TE/tests/pytorch/attention/test_kv_cache.py:86: UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).
    self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu")

tests/pytorch/attention/test_kv_cache.py: 288 warnings
  /home/klakhani/TE/transformer_engine/pytorch/attention/inference.py:435: UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).
    self.batch_indices_post_step = torch.range(

tests/pytorch/attention/test_kv_cache.py: 14 warnings
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================== 48 passed, 528 skipped, 880 warnings in 62.17s (0:01:02) ========================================================================================================

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani force-pushed the klakhani/maint/sm120-thd-flash-support branch from 674394b to 998b3b8 on February 20, 2026 at 18:40
KshitijLakhani and others added 3 commits March 2, 2026 15:31
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…pe instead of TH1 for sm120

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani force-pushed the klakhani/maint/sm120-thd-flash-support branch from dc282ea to b2f5864 on March 2, 2026 at 23:31
pre-commit-ci bot and others added 9 commits March 2, 2026 23:32
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…incorrect max logit calculation (includes padded tokens in max calculation)

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…pa arbitrary kernel call

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…clude a check for sm120

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani changed the title from "Enable sm120 support for fused attn if cuDNN is 9.18.1+" to "[PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1+" on Mar 11, 2026
@KshitijLakhani self-assigned this on Mar 11, 2026
@KshitijLakhani marked this pull request as ready for review on March 12, 2026 at 21:51

greptile-apps bot commented Mar 12, 2026

Greptile Summary

This PR enables SM120 (Blackwell) support for fused attention with THD (ragged sequence) layouts when cuDNN ≥ 9.18.1, by adapting the graph-building logic for a cuDNN stride validation difference on that architecture.

Key changes:

  • Stats tensor layout: On SM120, cuDNN rejects THD-style "packed" strides (stride[0] > dim[1]×dim[2]×dim[3]) in its support check. The fix uses BHSD-like dimensions with max_seqlen at plan-build time, keeping ragged offsets for variable-length boundaries. This changes the stats tensor from the compact [T, H, 1] shape to the padded [B, H, S, 1] shape.
  • Python-side stats masking: fused_attn_fwd now masks padded positions in the 4D stats tensor (using cu_seqlens_q deltas) before computing max_logit, and corrects amax_dims to always yield shape [h]. This also inadvertently fixes a latent dimension bug for pre-9.6.0 cuDNN + THD + return_max_logit. A masking sketch follows this list.
  • Layout restrictions: T3HD and TH3D QKV layouts remain blocked on SM120 (both in the Python backend selector and as a runtime assert in C++). KV-cache inference paths additionally disable both fused and flash attention on SM120 until fully verified.
  • use_ragged_stats flag: Consolidates the repeated is_ragged_q && cudnn_runtime_version >= 90600 condition into a single boolean that additionally gates on sm_arch_ < 120, simplifying the many call-sites within fwd_impl and bwd_impl.
  • Context parallelism: softmax_lse_in_packed_format is forced to False on SM120 so the CP ring-buffer accumulation treats the LSE in batched (not packed) format, matching the BHS1 output from cuDNN.
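
A minimal sketch of the padded-position masking described above, assuming a [b, h, s, 1] stats tensor and standard cumulative sequence lengths; the function name is hypothetical and this is not the actual fused_attn_fwd code:

import torch

# Hypothetical sketch: mask padding in a [b, h, s, 1] stats tensor using
# cu_seqlens_q deltas, then reduce to a per-head max_logit of shape [h].
def masked_max_logit(stats: torch.Tensor, cu_seqlens_q: torch.Tensor) -> torch.Tensor:
    b, h, s, _ = stats.shape
    seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]      # [b] true sequence lengths
    pos = torch.arange(s, device=stats.device)          # [s] position index
    valid = pos.unsqueeze(0) < seqlens.unsqueeze(1)     # [b, s] validity mask
    masked = stats.masked_fill(~valid[:, None, :, None], float("-inf"))
    return masked.amax(dim=(0, 2, 3))                   # reduce to shape [h]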

Confidence Score: 4/5

  • Safe to merge; changes are well-scoped to SM120 paths and all non-SM120 behavior is preserved.
  • The SM120-specific logic is internally consistent: fwd/bwd both skip the max_b/max_t_q substitution, both set use_ragged_stats = false, and the resulting BHS1 stats shape is correctly consumed by both the C++ backward and the Python masking code. The use_flash_attention = False gate for SM120+KV-cache is correctly propagated through the and-gate at lines 1108-1109 before being recombined. The only item is the redundant device queries (up to 3× per forward call) introduced across the call stack, which is a minor maintenance concern but not a correctness issue. Tests ran on SM120 hardware with no failures.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py — the new seqlens_q masking path is exercised only for THD+BHS1 stats (SM120 or pre-9.6.0 cuDNN with return_max_logit=True); worth verifying with a CP + SM120 end-to-end training test once KV-cache restrictions are lifted.

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn.cpp: Adds SM120 guard checks before calling fused_attn_arbitrary_seqlen_fwd/bwd for T3HD/TH3D layouts. The forward message is missing a period (noted in a prior review). Both forward and backward do independent device queries where the forward already queries inside _impl.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: Introduces a use_ragged_stats flag consolidating repeated conditions; skips the max_b/max_t_q dimension substitution on SM120, enabling the BHSD-like layout to pass cuDNN's stride check; updates stats tensor shapes and ragged-offset packing paths accordingly.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Correctly gates softmax_lse_in_packed_format for SM120 (compute capability >= 12.0): even with cuDNN >= 9.6, SM120 produces BHS1-shaped stats, so the packed-format flag must remain False (see the sketch after this list).
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Enables fused attention on SM120 with cuDNN >= 9.18.1 (except t3hd/th3d); disables fused and flash attention for KV cache on SM120. The use_flash_attention = False at line 565 is correctly propagated via the and gates at lines 1108-1109.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: Handles 4D (BHS1) stats tensors for SM120/pre-9.6 cuDNN: masks padded positions before computing max_logit, and fixes amax_dims to correctly reduce to shape [h] for both 3D and 4D stats tensors.
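
The context_parallel.py row above describes a packed-format gate; here is an illustrative standalone version. This is not TE's actual code: the function name is invented, and the integer cuDNN version encoding follows the 90600 convention used elsewhere in this summary:

import torch

# Illustrative gate: even with cuDNN >= 9.6, sm120 emits BHS1-shaped stats,
# so the packed (TH1) LSE format must stay disabled there.
def lse_in_packed_format(cudnn_runtime_version: int, device: int = 0) -> bool:
    major, minor = torch.cuda.get_device_capability(device)
    is_sm120_or_newer = (major, minor) >= (12, 0)
    return cudnn_runtime_version >= 90600 and not is_sm120_or_newer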

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend - utils.py] --> B{SM120 + cuDNN < 9.18.1?}
    B -- Yes --> C[Disable FusedAttention for THD]
    B -- No --> D{SM120 + t3hd/th3d layout?}
    D -- Yes --> C
    D -- No --> E[FusedAttention enabled for SM120 + THD]

    E --> F[fused_attn_fwd - Python]
    F --> G{T3HD or TH3D layout?}
    G -- Yes --> H[NVTE_ERROR - assert in nvte_fused_attn_fwd]
    G -- No --> I[fused_attn_arbitrary_seqlen_fwd_impl]

    I --> J{SM120 and ragged Q/KV?}
    J -- Yes --> K[Keep b=batch, s_q=max_seqlen_q\nBHSD-like layout - passes cuDNN stride check]
    J -- No --> L[b=max_b, s_q=max_t_q\nPacked layout - quantization bucket]

    K --> M{use_ragged_stats?\nis_ragged_q AND cudnn>=9.6 AND sm<120}
    L --> M
    M -- False - SM120 --> N[Stats: BHS1 stride\nbatch x heads x max_seqlen_q x 1]
    M -- True - non-SM120 --> O[Stats: TH1 stride\nnum_tokens x heads x 1\nwith ragged offset]

    N --> P[return_max_logit path in Python\nmax_tensor.ndim==4: mask padded positions\namax over dims 0,2,3 to get shape h]
    O --> Q[return_max_logit path in Python\nmax_tensor.ndim==3: amax over dims 0,2 to get shape h]

    R[context_parallel.py\nAttnFuncWithCPAndKVP2P] --> S{SM120?}
    S -- Yes --> T[softmax_lse_in_packed_format=False\nUse BHS1 format]
    S -- No --> U[softmax_lse_in_packed_format = cudnn>=9.6\nUse TH1 packed format]
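
For concreteness, a small assumed-values comparison of the two stats layouts named in the flowchart (packed TH1 vs padded BHS1); the sizes are made up for illustration:

import torch

# Two sequences of lengths 3 and 5, with h = 4 heads.
seqlens = torch.tensor([3, 5])
b, h = len(seqlens), 4
t = int(seqlens.sum())              # total tokens T = 8
max_s = int(seqlens.max())          # max_seqlen_q S = 5

stats_th1 = torch.zeros(t, h, 1)            # packed [T, H, 1]: no padding; relies
                                            # on ragged offsets for boundaries
stats_bhs1 = torch.zeros(b, h, max_s, 1)    # padded [B, H, S, 1]: positions past
                                            # each seqlen are padding to be masked
print(stats_th1.shape, stats_bhs1.shape)    # [8, 4, 1] and [2, 4, 5, 1]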

Last reviewed commit: bcfef90

Comment on lines +639 to +641
NVTE_ERROR(
    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
    "Use thd_thd_thd or other THD layouts instead.");

Missing period in forward error message

The forward error message is missing a period that is present in the corresponding backward error message (line 748). Minor inconsistency but worth fixing for uniformity.

Suggested change

     NVTE_ERROR(
-        "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
+        "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
         "Use thd_thd_thd or other THD layouts instead.");

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@KshitijLakhani (Collaborator, Author) commented:

/te-ci L0 L1

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani (Collaborator, Author) commented:

/te-ci L0 L1

@KshitijLakhani requested a review from cyanguwa on March 13, 2026 at 21:21