This pull request expands multi-head attention kernel support in ONNX
Runtime, most notably adding support for both 3D (`BSNH`) and 4D
(`BNSH`) QKV input formats. The changes improve the flexibility,
correctness, and maintainability of attention operations across the CPU
and CUDA implementations.
### Expanded QKV Input Format Support
* Added support for the 4D QKV input format (`Q_K_V_BNSH`) in the CUDA
attention kernels, including handling of the cases with and without
past/present state, and enforcing that bias is not supported for this
format. This includes logic to avoid unnecessary transposes and to write
outputs directly when possible (see the layout sketch below).
[[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R264-R265)
[[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R343-R354)
[[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R388-L388)
[[4]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R426-R435)
[[5]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)
[[6]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R747-R748)
[[7]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
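To make the layout difference concrete, here is a minimal, standalone sketch of the index math (illustrative only, not ONNX Runtime code): `BSNH` is the 3D input `(B, S, N*H)` viewed as `(B, S, N, H)`, while `BNSH` is the 4D input `(B, N, S, H)`, whose per-head `(S, H)` blocks are already contiguous, so the per-head attention GEMMs can consume it without a transpose.

```cpp
#include <cstdint>

// B = batch, S = sequence length, N = number of heads, H = head size.
// Flat offset of element (b, s, n, h) in the 3D BSNH layout (B, S, N*H):
inline int64_t IndexBSNH(int64_t b, int64_t s, int64_t n, int64_t h,
                         int64_t S, int64_t N, int64_t H) {
  return ((b * S + s) * N + n) * H + h;
}

// Flat offset of element (b, n, s, h) in the 4D BNSH layout (B, N, S, H).
// Each head's (S, H) block is contiguous, which is what the per-head
// attention GEMMs want, so no BSNH -> BNSH transpose is required.
inline int64_t IndexBNSH(int64_t b, int64_t n, int64_t s, int64_t h,
                         int64_t S, int64_t N, int64_t H) {
  return ((b * N + n) * S + s) * H + h;
}
```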
### Kernel and Operator Documentation Updates
* Updated `OperatorKernels.md` to document the new `Attention` operator
inputs and outputs for both 3D and 4D formats, specifying supported
tensor types for each input.
### Correctness and Consistency Fixes
* Fixed the computation of causal attention indices in the CUDA softmax
kernels by correcting the offset calculation used for causal masking
(see the index sketch after this list).
[[1]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL168-R168)
[[2]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL244-R244)
[[3]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL336-R336)
[[4]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL442-R442)
* Updated the workspace allocation logic for QKV preparation to ensure
correct workspace usage for the new formats (see the sizing sketch after
this list).
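As a hedged sketch of the causal masking arithmetic (this follows the usual convention in attention kernels, not necessarily the exact code the PR touches): with past length `P` and new sequence length `S` (so `total = P + S`), query row `i` of the new sequence may attend to key columns `j <= P + i`.

```cpp
// Exclusive end of the valid key-column range under causal masking.
// total_seq_len = past_sequence_length + sequence_length; row is the
// 0-based query position within the new sequence. Valid columns are
// j in [0, total_seq_len - sequence_length + row].
inline int CausalEndExclusive(int total_seq_len, int sequence_length, int row) {
  return total_seq_len - sequence_length + row + 1;
}
```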
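And a hypothetical sizing helper for the workspace change (the name, signature, and layout here are illustrative assumptions, not ONNX Runtime's actual allocator code): when Q/K/V arrive in `BNSH` and need no transpose, the QKV scratch buffer can shrink to zero; otherwise it must hold the transposed Q, K, and V.

```cpp
#include <cstddef>

// Hypothetical sketch, not ONNX Runtime's actual workspace computation.
size_t QkvWorkspaceBytes(bool needs_qkv_transpose, size_t element_size,
                         size_t batch, size_t num_heads, size_t q_seq_len,
                         size_t kv_seq_len, size_t head_size) {
  if (!needs_qkv_transpose) {
    return 0;  // BNSH inputs are consumed in place.
  }
  const size_t q_bytes = element_size * batch * num_heads * q_seq_len * head_size;
  const size_t k_or_v_bytes = element_size * batch * num_heads * kv_seq_len * head_size;
  return q_bytes + 2 * k_or_v_bytes;  // transposed Q + K + V
}
```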
### Attention Parameter and Helper Refactoring
* Added an `is_output_bnsh` field to `AttentionParameters` to indicate
the output format, and updated the output placement and transposition
logic to key off it (sketched after this list).
[[1]](diffhunk://#diff-e742290164e1e1fa0152840db2a1b83354e153153df19a2762b58655e49b7f9bR37)
[[2]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
* Refactored the CPU attention implementation to use the new
`attention_helper` namespace for output-mode enums and output shape
computation, improving code clarity and maintainability (see the shape
sketch after this list).
[[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R5)
[[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L118-R125)
[[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L143-R149)
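A sketch of how such a flag typically drives output placement (the field name comes from the PR; the surrounding struct and helper below are illustrative, not the actual ORT code): when the requested output layout matches the `BNSH` compute layout, the result is copied straight through, otherwise it is transposed to `BSNH`.

```cpp
#include <algorithm>
#include <cstddef>

struct AttentionParametersSketch {
  int batch_size, num_heads, sequence_length, head_size;
  bool is_output_bnsh;  // new field added by this PR
};

template <typename T>
void PlaceOutput(const AttentionParametersSketch& p, const T* bnsh, T* out) {
  const size_t B = p.batch_size, N = p.num_heads;
  const size_t S = p.sequence_length, H = p.head_size;
  if (p.is_output_bnsh) {
    std::copy(bnsh, bnsh + B * N * S * H, out);  // layouts match: direct copy
    return;
  }
  for (size_t b = 0; b < B; ++b)                 // BNSH -> BSNH transpose
    for (size_t n = 0; n < N; ++n)
      for (size_t s = 0; s < S; ++s)
        for (size_t h = 0; h < H; ++h)
          out[((b * S + s) * N + n) * H + h] = bnsh[((b * N + n) * S + s) * H + h];
}
```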
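And a minimal sketch of the output-shape computation such a helper centralizes (the enum and function names are illustrative, not necessarily the exact `attention_helper` API):

```cpp
#include <cstdint>
#include <vector>

enum class OutputModeSketch { kBSNH3D, kBNSH4D };

std::vector<int64_t> ComputeOutputShape(OutputModeSketch mode, int64_t batch,
                                        int64_t num_heads, int64_t seq_len,
                                        int64_t head_size) {
  if (mode == OutputModeSketch::kBNSH4D) {
    return {batch, num_heads, seq_len, head_size};  // 4D (B, N, S, H)
  }
  return {batch, seq_len, num_heads * head_size};   // 3D (B, S, N*H)
}
```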
### Minor Cleanups
* Removed outdated assertions and improved the debug output strings of
the QKV preparation functions to clarify format and state handling.
[[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L254)
[[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L363)
[[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)
These changes collectively improve the flexibility, correctness, and
maintainability of attention kernel implementations in ONNX Runtime,
especially for advanced transformer models and large language model
workloads.
**NOT supported in this PR**
- Boolean mask
- GQA
- Softcap
- Softmax precision
- qk_output_mode other than -1 and 0