Commit f0530d2
[TritonNVIDIAGPU] Add dependency tokens to TMEM ops (#6520)
The Triton middle-end has perfect dependency+modref information about
TMEM (and shared memory) because the middle-end itself introduces these
memory accesses by expanding chains of SSA ops. For example,
`HoistTMEMAlloc` is essentially a form of reg-2-mem for MMA accumulators.
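
As a rough illustration of what "reg-2-mem for MMA accumulators" means, the toy Python sketch below rewrites a chain of SSA values threaded through MMA ops into ops on one explicit buffer. The tuple-based IR and the names `reg2mem`, `alloc`, `store`, `mma_inplace`, and `load` are hypothetical; they are not Triton's actual ops or the `HoistTMEMAlloc` implementation.

```python
# Toy model only: ops are plain tuples, not real Triton IR.
# SSA form: each MMA consumes the previous accumulator value and yields a new one.
ssa_chain = [
    ("acc1", "mma", "a0", "b0", "acc0"),
    ("acc2", "mma", "a1", "b1", "acc1"),
    ("acc3", "mma", "a2", "b2", "acc2"),
]


def reg2mem(chain, buf="tmem0"):
    """Rewrite an accumulator chain into ops on a single memory buffer.

    Assumes each op's accumulator operand is the previous op's result.
    Because the rewriter creates every access to `buf` itself, it knows
    the full access order without any further analysis.
    """
    init_acc = chain[0][4]
    out = [("alloc", buf), ("store", buf, init_acc)]
    prev = init_acc
    for result, kind, a, b, acc in chain:
        assert kind == "mma" and acc == prev, "not a simple accumulator chain"
        out.append(("mma_inplace", a, b, buf))  # accumulates into buf
        prev = result
    out.append(("load", buf, prev))  # the final value is read back as `prev`
    return out


mem_ops = reg2mem(ssa_chain)
for op in mem_ops:
    print(op)
```

The point of the sketch is that the ordering of `store`, `mma_inplace`, and `load` is known by construction at the moment the memory ops are created.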
Despite this, the dependency and alias analysis needed by
`HoistTMEMAlloc`, warp specialization, and the pipeliner relies on ad-hoc
checks that are not always correct and are becoming increasingly
complex. Instead of building stronger memory analysis, we can simply
keep the information the compiler already has.
This PR adds tokens to all the ops that touch TMEM (except `TMEMCopyOp`,
since it is not used in the middle-end). The tokens act as a form of
MemorySSA (memory variable lattice encoded in the IR), and the
middle-end now uses them to check aliasing, modref, etc. instead of
scanning the IR. Consequently, the transformations are more robust and
easier to maintain, at the cost of some necessary extra book-keeping.
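
The token idea can be sketched with a toy model: each op that touches a buffer consumes the token produced by the previous access and produces a new one, so finding the write a load depends on is a walk along the token chain rather than a scan of the surrounding IR. The `Op` class and the helpers `last_writer` and `directly_follows` are hypothetical illustrations, not Triton's C++ API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Op:
    """Hypothetical memory op: consumes a dependency token, produces one."""
    name: str
    writes: bool
    dep: Optional["Op"] = None  # producer of the token this op consumes


def last_writer(op: Op) -> Optional[Op]:
    """Follow the token chain backward to the nearest op that writes."""
    cur = op.dep
    while cur is not None and not cur.writes:
        cur = cur.dep
    return cur


def directly_follows(earlier: Op, later: Op) -> bool:
    """True if `later` consumes the token produced by `earlier`,
    i.e. no other access to the buffer sits between them."""
    return later.dep is earlier


# alloc -> store -> mma -> load, threaded through tokens.
alloc = Op("alloc", writes=False)
store = Op("store", writes=True, dep=alloc)
mma = Op("mma", writes=True, dep=store)
load = Op("load", writes=False, dep=mma)

print(last_writer(load).name)         # mma
print(last_writer(mma).name)          # store
print(directly_follows(store, load))  # False: the mma sits between them
```

In this model a transformation that moves or removes an op has to rewire the `dep` links, which is the extra book-keeping mentioned above.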
This will greatly simplify the dependence analysis needed by more
complex warp specialization, and help with composing warp specialization
with the pipeliner (cc @htyu @manman-ren).
There would be a pretty big performance cliff if this PR were wrong
(i.e. if it failed to pipeline or warp specialize), so I sanity-checked
that it did not break pipelining.
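
One way to read the listings below is to compute the relative change per kernel. The sketch uses the first numeric column of the two matmul listings, copied as-is; the source does not label the columns, so no unit is assumed here, and the helper name `relative_change` is hypothetical.

```python
# First numeric column of the matmul listings ("before" and "after").
before = {
    "matmul_kernel": 708.163,
    "matmul_kernel_descriptor_persistent": 935.792,
    "matmul_kernel_descriptor_persistent_ws": 922.666,
    "matmul_kernel_persistent": 856.643,
    "matmul_kernel_tma": 792.424,
    "matmul_kernel_tma_persistent": 1020.997,
    "matmul_kernel_tma_persistent_ws": 1134.083,
    "matmul_kernel_tma_ws": 799.650,
}
after = {
    "matmul_kernel": 703.378,
    "matmul_kernel_descriptor_persistent": 936.461,
    "matmul_kernel_descriptor_persistent_ws": 938.393,
    "matmul_kernel_persistent": 856.351,
    "matmul_kernel_tma": 785.072,
    "matmul_kernel_tma_persistent": 1024.165,
    "matmul_kernel_tma_persistent_ws": 1125.056,
    "matmul_kernel_tma_ws": 800.940,
}


def relative_change(old, new):
    """Percentage change of each entry in `new` relative to `old`."""
    return {k: (new[k] - old[k]) / old[k] * 100.0 for k in old}


changes = relative_change(before, after)
for name, pct in changes.items():
    print(f"{name}: {pct:+.2f}%")
```

A failure to pipeline or warp specialize would be expected to show up as a large swing in at least one kernel rather than small differences in both directions.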
### Performance numbers after
```
├─ 703.378 976.992 matmul_kernel [M=8192, N=8192, K=512]
├─ 936.461 733.821 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 938.393 732.310 matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─ 856.351 802.468 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 785.072 875.327 matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1024.165 670.981 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1125.056 610.810 matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─ 800.940 857.986 matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```
```
fused-attention-batch4-head32-d64-fwd-causal=True:
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 183.032906 176.540661
1 2048.0 384.363999 417.633483
2 4096.0 471.816004 511.814693
3 8192.0 519.752669 566.761880
4 16384.0 545.707761 595.042579
fused-attention-batch4-head32-d64-fwd-causal=False:
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 364.631059 364.685641
1 2048.0 492.108137 536.102664
2 4096.0 532.795804 580.166599
3 8192.0 550.670842 599.591255
4 16384.0 559.480705 608.551411
fused-attention-batch4-head32-d64-bwd-causal=True:
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 144.731066 152.721176
1 2048.0 234.101200 234.195236
2 4096.0 293.602665 293.519568
3 8192.0 331.644550 331.388321
4 16384.0 355.252999 354.861517
```
```
Problem Shape = 8192x8192x512
└─ 974.209 705.458 block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```
### Performance numbers before
```
├─ 708.163 970.391 matmul_kernel [M=8192, N=8192, K=512]
├─ 935.792 734.346 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 922.666 744.793 matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─ 856.643 802.195 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 792.424 867.206 matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1020.997 673.063 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1134.083 605.948 matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─ 799.650 859.369 matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```
```
fused-attention-batch4-head32-d64-fwd-causal=True:
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 181.507652 183.077756
1 2048.0 384.836411 416.908797
2 4096.0 471.260742 512.709282
3 8192.0 519.896730 566.172554
4 16384.0 545.181917 595.246382
fused-attention-batch4-head32-d64-fwd-causal=False:
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 368.266771 373.516950
1 2048.0 492.137719 535.968650
2 4096.0 533.092876 580.134559
3 8192.0 550.571575 599.455669
4 16384.0 559.555689 608.442981
fused-attention-batch4-head32-d64-bwd-causal=True:
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 151.081525 155.745186
1 2048.0 234.359406 234.108984
2 4096.0 293.584945 293.689437
3 8192.0 331.633380 331.669234
4 16384.0 355.077635 354.963313
```
```
Problem Shape = 8192x8192x512
└─ 972.794 706.484 block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```
1 parent 3932686 commit f0530d2
31 files changed
Lines changed: 954 additions & 745 deletions
File tree
- include/triton/Dialect
  - TritonGPU/Transforms
  - TritonNvidiaGPU
    - IR
    - Transforms
- lib/Dialect
  - TritonGPU/Transforms
    - Pipeliner
    - WarpSpecialization
  - TritonNvidiaGPU
    - IR
    - Transforms
- python
  - test/unit/language
  - tutorials
- test/TritonGPU
- third_party/nvidia
  - backend