Commit 1817f4a

# [Reland] Attention(23) CUDA (#27030)
## Reland reason
Reland #26466
The previous PR was reverted because it failed the following tests:
1. Windows GPU CUDA CI Pipeline Test Job
2. Windows GPU TensorRT CI Pipeline Test Job
This PR includes the correct
[fix](cc7a947).
---
## Description
This pull request introduces significant improvements and expanded
support for multi-head attention kernels in ONNX Runtime, particularly
focusing on supporting both 3D (`BSNH`) and 4D (`BNSH`) QKV input
formats. The changes enhance flexibility, correctness, and
maintainability for attention operations across CPU and CUDA
implementations.
### Expanded QKV Input Format Support
* Added support for the 4D QKV input format (`Q_K_V_BNSH`) in CUDA attention
kernels, including proper handling both with and without past/present
states, and enforcing that bias is not supported for this format. This
includes logic to avoid unnecessary transposes and to write outputs
directly when possible.
[[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R264-R265)
[[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R343-R354)
[[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R388-L388)
[[4]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R426-R435)
[[5]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)
[[6]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R747-R748)
[[7]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
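As a rough illustration of why the 4D path can skip transposes: the packed 3D (`BSNH`) and 4D (`BNSH`) layouts address the same logical element `q[b, n, s, h]` at different flat offsets. The helpers below are an illustrative sketch of the layout arithmetic, not ONNX Runtime's actual code:

```python
# Illustrative layout sketch (assumed packing, not the actual CUDA kernel).
def bsnh_offset(b, s, n, h, S, N, H):
    # 3D "BSNH" packing (B, S, N*H): heads interleaved along the hidden dim.
    return ((b * S + s) * N + n) * H + h

def bnsh_offset(b, n, s, h, S, N, H):
    # 4D "BNSH" (B, N, S, H): each head's sequence is contiguous, so a kernel
    # can read it without a transpose and, with an output-in-BNSH mode,
    # write the result back directly in the same layout.
    return ((b * N + n) * S + s) * H + h

B, S, N, H = 2, 4, 3, 8
# The same logical element lands at two different flat positions:
assert bsnh_offset(1, 2, 0, 5, S, N, H) == 149
assert bnsh_offset(1, 0, 2, 5, S, N, H) == 117
```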
### Kernel and Operator Documentation Updates
* Updated `OperatorKernels.md` to document the new `Attention` operator
inputs and outputs for both 3D and 4D formats, specifying supported
tensor types for each input.
### Correctness and Consistency Fixes
* Fixed the computation of causal attention indices in CUDA softmax
kernels by clarifying and correcting the offset calculation for causal
masking.
[[1]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL168-R168)
[[2]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL244-R244)
[[3]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL336-R336)
[[4]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL442-R442)
* Updated workspace allocation logic for QKV preparation to ensure
correct workspace usage for new formats.
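The causal-offset semantics being corrected can be sketched as follows. This mask construction is an assumed illustration of the standard convention, not the CUDA softmax kernel itself:

```python
# Minimal causal-mask sketch (assumed semantics, not the CUDA code):
# query position i may attend to key position j only when
# j <= i + (kv_sequence_length - q_sequence_length).
def causal_mask(q_len, kv_len):
    offset = kv_len - q_len  # past keys shift the allowed window right
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]

# With q_len == kv_len this is the familiar lower-triangular mask.
assert causal_mask(3, 3) == [
    [True, False, False],
    [True, True, False],
    [True, True, True],
]
# With 2 past keys, every query also attends to those past positions.
assert causal_mask(2, 4) == [
    [True, True, True, False],
    [True, True, True, True],
]
```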
### Attention Parameter and Helper Refactoring
* Added `is_output_bnsh` field to `AttentionParameters` to indicate
output format and updated logic to use this for output placement and
transposition decisions.
[[1]](diffhunk://#diff-e742290164e1e1fa0152840db2a1b83354e153153df19a2762b58655e49b7f9bR37)
[[2]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
* Refactored CPU attention implementation to use the new
`attention_helper` namespace for output mode enums and output shape
computation, improving code clarity and maintainability.
[[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R5)
[[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L118-R125)
[[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L143-R149)
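A hypothetical sketch of the output-shape decision that a flag like `is_output_bnsh` drives (the function name and packing convention below are illustrative, not ONNX Runtime's actual helpers):

```python
# Hypothetical sketch: choose the attention output shape from the flag.
def output_shape(batch, q_seq, num_heads, head_size, is_output_bnsh):
    if is_output_bnsh:
        # 4D BNSH output: written directly, no final transpose needed.
        return (batch, num_heads, q_seq, head_size)
    # Packed 3D BSNH output: heads folded into the hidden dimension.
    return (batch, q_seq, num_heads * head_size)

assert output_shape(2, 4, 3, 8, True) == (2, 3, 4, 8)
assert output_shape(2, 4, 3, 8, False) == (2, 4, 24)
```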
### Minor Cleanups
* Removed outdated asserts and improved debug output strings for QKV
preparation functions to clarify format and state handling.
[[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L254)
[[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L363)
[[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)
These changes collectively improve the flexibility, correctness, and
maintainability of attention kernel implementations in ONNX Runtime,
especially for advanced transformer models and large language model
workloads.
**NOT supported in this PR**
- Boolean mask
- GQA
- Softcap
- Softmax precision
- qk_output_mode other than -1 and 0
- **is_causal=True && q_sequence_length != kv_sequence_length**
15 files changed: +745 −396 lines