
Commit cbb02a6

kvbp2k authored and meta-codesync[bot] committed
[AutoWS] Add Swing Modulo Scheduling (SMS) as alternative to Rau's IMS (#1257)
Summary:
Add modulo scheduling for automatic warp-specialization MMA annotation. Enables `TRITON_USE_MODULO_SCHEDULE=sms` to automatically derive pipeline stage assignments that match hand-tuned `attrs=` annotations on FA BWD.

## FA BWD performance (B200, TRITON_USE_META_WS=1 TRITON_USE_META_PARTITION=1)

Run example:

`TRITON_USE_META_WS=1 TRITON_USE_META_PARTITION=1 TRITON_ALWAYS_COMPILE=1 TRITON_USE_MODULO_SCHEDULE=sms python -m pytest python/tutorials/fused-attention-ws-device-tma.py -k "bwd and 128 and 1024 and 16 and 8" -v`

| Shape | Baseline TFLOPS | SMS TFLOPS | Diff |
|---|---|---|---|
| Z=4 H=16 N=2048 D=128 | 409.4 | 409.9 | +0.1% |
| Z=8 H=16 N=1024 D=128 | 324.7 | 323.3 | -0.4% |
| Z=1 H=32 N=4096 D=128 | 471.2 | 472.0 | +0.2% |

## What it does

The modulo scheduling pass runs before the WS pass and sets `tt.autows` annotations on MMA ops. These annotations tell the downstream pipeliner which MMA ops should be grouped into pipeline stages for cross-iteration overlap.

### Scheduling algorithms

Selected via `TRITON_USE_MODULO_SCHEDULE=<algo>`:

| Value | Algorithm | Description |
|-------|-----------|-------------|
| `sms` | Swing Modulo Scheduling | Slack-based ordering, directional placement (Llosa et al., PACT 1996) |
| `exhaustive` | Branch-and-bound | Explores all valid stage assignments with memory feasibility checks |
| `random` | Random sampling | Dependency-aware random stage assignments |
| `1` | Rau's IMS | Critical-path ordering with ejection backtracking (Rau, 1994) |

### Key design decisions

**selfLatency = 1 for all GPU pipelines.** GPU execution units are deeply pipelined — a new instruction can be issued every ~1 cycle. Using completion latency (e.g., 900 cycles for MMA) as `selfLatency` inflated ResMII to 4500 for FA BWD (5 MMAs), causing all schedulers to fail. With `selfLatency=1`, RecMII (data dependencies) correctly drives the schedule.

**Stage assignment via transitive MMA dependency counting.** After the scheduler assigns cycles, the pass derives pipeline stages:

- 0-1 transitive MMA predecessors → stage 0 (prefetchable)
- 2+ transitive MMA predecessors → stage 1 (gated on multiple prior results)

This matches the hand-tuned FA BWD partition exactly:

| MMA | Transitive MMA deps | Stage | Order |
|-----|---------------------|-------|-------|
| qkT = dot(k, qT) | 0 | 0 | 0 |
| dpT = dot(v, do^T) | 0 | 0 | 0 |
| dv += dot(ppT, do) | 1 (qkT) | 0 | 1 |
| dq = dot(dsT^T, k) | 2 (qkT, dpT) | 1 | 0 |
| dk += dot(dsT, qT) | 2 (qkT, dpT) | 1 | 0 |

**Independent MMAs share the same order** within a stage to avoid barrier deadlocks. Annotations are skipped for loops with existing `tt.autows` from Python `attrs=` or when all MMAs land in the same stage.

## Files changed

- `SwingScheduler.cpp/h` — SMS implementation
- `ExhaustiveScheduler.cpp/h` — exhaustive + random search
- `ModuloReservationTable.cpp/h` — Rau's IMS + dispatch logic
- `LatencyModel.cpp` — selfLatency=1 fix, added 8 missing tensor op latencies
- `ModuloSchedulePass.cpp` — dependency-based stage/order assignment, loop filtering
- `DataDependenceGraph.cpp/h` — DDG construction
- `compiler.py` — pass ordering, modulo pass before data partitioning
- `knobs.py` — `TRITON_USE_MODULO_SCHEDULE` as `env_opt_str`
- `GetEnv.hpp` — register env var for C++ access
- `test_modulo_schedule.py` — E2E tests for all 3 algorithms
- `ws_global_instruction_scheduling.md` — design doc with SMS details + benchmarks
- `CMakeLists.txt` — SwingScheduler.cpp, ExhaustiveScheduler.cpp

Pull Request resolved: #1257

Test Plan:
- [x] `pytest python/test/unit/cuda/test_modulo_schedule.py` — all 3 algos pass
- [x] FA BWD correctness: 12/12 tests passed with `TRITON_USE_MODULO_SCHEDULE=sms`
- [x] FA BWD perf matches baseline (±0.5%)

Authored with Claude.

Reviewed By: htyu

Differential Revision: D101116271

Pulled By: kvbp2k

fbshipit-source-id: 345998592443c0279cd0efe8ce08613f64b2856a
1 parent 90fe2f4 commit cbb02a6

17 files changed

Lines changed: 1561 additions & 95 deletions

docs/design/ws_global_instruction_scheduling.md

Lines changed: 114 additions & 0 deletions
@@ -20,6 +20,7 @@ This document is based on the original design in [WS global instruction scheduli

- [Step 1: Compute Minimum Initiation Interval (II)](#step-1-compute-minimum-initiation-interval-ii)
- [Step 2: Modulo Reservation Table Scheduling](#step-2-modulo-reservation-table-scheduling)
- [Background: Rau's Iterative Modulo Scheduling](#background-raus-iterative-modulo-scheduling)
- [Alternative: Swing Modulo Scheduling (SMS)](#alternative-swing-modulo-scheduling-sms)
- [Step 2.5: Compute Cluster IDs from the Modulo Schedule](#step-25-compute-cluster-ids-from-the-modulo-schedule)
- [Step 3: Derive Per-Region Pipeline Depth from the Modulo Schedule](#step-3-derive-per-region-pipeline-depth-from-the-modulo-schedule)
- [Step 4: Handling Resource Pressure (SMEM/TMEM Budget)](#step-4-handling-resource-pressure-smemtmem-budget)
@@ -349,6 +350,8 @@ The algorithm as described has several limitations:

7. **Register allocation is approximate**: Pass B Step 4 estimates register usage from live variable counts but doesn't perform full register allocation. The actual register count is determined by the compiler backend (ptxas), which may differ from the estimate and cause spills that the schedule didn't anticipate.

8. **SMS limitations**: The SMS implementation's simplified ASAP/ALAP computation (no II-dependent recurrence bounds) and BFS ordering (no SCC prioritization) may produce suboptimal schedules for kernels with multiple interacting recurrence circuits, such as FA backward with 5 MMA ops and cross-iteration accumulator/softmax/pointer dependencies. For single-MMA kernels (GEMM), SMS and Rau produce identical schedules.

---

## Inputs
@@ -760,6 +763,117 @@ def modulo_schedule(DDG, latencies, unit_map, MinII):
        II += 1  # Try larger II
```

#### Alternative: Swing Modulo Scheduling (SMS)

Swing Modulo Scheduling (SMS; J. Llosa, A. Gonzalez, E. Ayguade, M. Valero, "Swing Modulo Scheduling: A Lifetime-Sensitive Approach", PACT 1996) avoids backtracking by using slack-based node ordering and directional placement.

**Key differences from Rau's IMS:**

| Property | Rau's IMS | SMS |
|----------|-----------|-----|
| Complexity | Potentially exponential (backtracking) | O(n) per II attempt |
| Node ordering | Critical-path height (bottom-up) | Slack = ALAP - ASAP (tightest first) |
| Placement | Earliest free slot, eject if blocked | Top-down for successors, bottom-up for predecessors |
| Register pressure | Not considered | Reduced by keeping producer-consumer pairs close |

**SMS Algorithm:**

1. **Compute ASAP/ALAP**: Forward/backward relaxation including loop-carried edges (II-dependent: `ASAP[v] >= ASAP[u] + latency - distance * II`), recomputed for each candidate II. Slack = ALAP - ASAP measures scheduling freedom.

2. **Ordering phase (swing)**: Start with the minimum-slack op (most constrained). Then BFS-expand: add its successors (marked top-down) sorted by ascending slack, then its predecessors (marked bottom-up) sorted by ascending slack. This alternation is the "swing" — it keeps producers and consumers adjacent in the schedule.

3. **Scheduling phase**: For each op in swing order:
   - **Top-down** ops: place at the earliest free slot from `earliest` upward (data is ready, issue immediately).
   - **Bottom-up** ops: place at the latest free slot from `latest` downward (defer production, reducing live range and register pressure).
```python
def sms_schedule(DDG, latencies, unit_map, MinII):
    for II in range(MinII, MinII + 11):  # capped at MinII+10
        # Recompute per-II: loop-carried edges depend on II
        asap = compute_ASAP(DDG, latencies, II)
        alap = compute_ALAP(DDG, latencies, asap, II)
        slack = {op: alap[op] - asap[op] for op in DDG.nodes}

        table = ReservationTable(II)
        scheduled = {}

        # Ordering: BFS from min-slack seed
        seed = min(DDG.nodes, key=lambda n: slack[n])
        order = [(seed, True)]  # (node, is_top_down)
        visited = {seed}
        for node, _ in order:
            # Successors → top-down
            for s in sorted(successors(node), key=lambda n: slack[n]):
                if s not in visited:
                    order.append((s, True))
                    visited.add(s)
            # Predecessors → bottom-up
            for p in sorted(predecessors(node), key=lambda n: slack[n]):
                if p not in visited:
                    order.append((p, False))
                    visited.add(p)

        # Placement
        success = True
        for op, top_down in order:
            earliest = compute_earliest(op, scheduled, DDG, latencies, II)
            latest = compute_latest(op, scheduled, DDG, latencies, II)
            if top_down:
                slot = table.find_free(earliest, unit_map[op])
            else:
                slot = table.find_free_reverse(latest, earliest, unit_map[op])
            if slot is None:
                slot = table.find_free(earliest, unit_map[op])  # fallback
            if slot is None:
                success = False
                break
            table.reserve(slot, unit_map[op], op)
            scheduled[op] = slot

        if success:
            return scheduled, II
    return None
```
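The pseudocode treats `compute_ASAP` and the reservation table as black boxes. A minimal, self-contained sketch of the mechanics behind them (hypothetical names and signatures chosen for illustration; not the actual `SwingScheduler.cpp` internals, whose call shapes differ):

```python
class ReservationTable:
    """Modulo reservation table: one issue slot per (cycle mod II, pipeline)."""

    def __init__(self, ii):
        self.ii = ii
        self.used = set()  # occupied (cycle % II, pipeline) slots

    def find_free(self, start, pipe):
        # Scan upward from `start`; only II distinct modulo slots exist,
        # so II probes suffice.
        for c in range(start, start + self.ii):
            if (c % self.ii, pipe) not in self.used:
                return c
        return None

    def find_free_reverse(self, latest, earliest, pipe):
        # Scan downward from `latest` to `earliest` (bottom-up placement).
        for c in range(latest, earliest - 1, -1):
            if (c % self.ii, pipe) not in self.used:
                return c
        return None

    def reserve(self, cycle, pipe, op):
        self.used.add((cycle % self.ii, pipe))


def compute_ASAP(nodes, edges, ii, max_iters=1000):
    """Relax ASAP[v] >= ASAP[u] + latency - distance * II to a fixed point.

    edges: iterable of (u, v, latency, distance); distance > 0 marks
    loop-carried edges, which is why ASAP depends on the candidate II.
    """
    asap = {n: 0 for n in nodes}
    for _ in range(max_iters):  # convergence limit, as in the text
        changed = False
        for u, v, lat, dist in edges:
            bound = asap[u] + lat - dist * ii
            if bound > asap[v]:
                asap[v] = bound
                changed = True
        if not changed:
            break
    return asap
```

For a three-op recurrence `A→B→C→A` with latencies 2, 3, 1 and one loop-carried back edge, `compute_ASAP` at `II = 6` settles at `{A: 0, B: 2, C: 5}`: the back-edge bound `5 + 1 - 6 = 0` is already satisfied, so the relaxation converges.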
**Implementation status:** SMS is available via `TRITON_USE_MODULO_SCHEDULE=sms`. Source: `SwingScheduler.cpp`. The implementation has the following simplifications relative to the paper:

1. **No recurrence-aware ordering.** The paper identifies SCCs, orders them by RecMII contribution, and schedules the most critical recurrence first. The implementation uses simple BFS from the minimum-slack node.

2. **Fallback on placement failure.** When the directional scan finds no free slot, the implementation falls back to `find_free` from `earliest`. The paper would fail at this II and increment.

3. **BFS follows all DDG edges**, including loop-carried ones (distance > 0). The paper's ordering only follows distance-0 edges.

ASAP/ALAP include loop-carried edges and are recomputed per-II: `ASAP[v] >= ASAP[u] + latency - distance * II`, with a convergence limit of 1000 iterations.

**selfLatency model:** All pipelines use `selfLatency = 1` because GPU execution units are deeply pipelined — a new instruction can be issued every ~1 cycle. This makes ResMII negligible (equal to the op count on the busiest pipeline) and lets RecMII (data dependencies) drive the schedule. Without this fix, SMS fails on FA backward (ResMII=4500 from 5 MMAs × 900 selfLatency each).
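The ResMII arithmetic can be sanity-checked with a toy calculation (an illustrative sketch; `res_mii` is a made-up helper name, not the pass's API):

```python
def res_mii(ops, self_latency):
    """ResMII lower bound: total per-iteration issue-slot demand
    on the busiest pipeline. ops: list of (name, pipeline) pairs."""
    demand = {}
    for name, pipe in ops:
        demand[pipe] = demand.get(pipe, 0) + self_latency(name)
    return max(demand.values())


# FA backward: 5 MMA ops, all on the TC pipeline.
mmas = [(f"mma{i}", "TC") for i in range(5)]

# Completion latency as selfLatency: ResMII explodes to 5 * 900 = 4500.
assert res_mii(mmas, lambda op: 900) == 4500

# Issue latency of a deeply pipelined unit: ResMII is just the op count.
assert res_mii(mmas, lambda op: 1) == 5
```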
**Stage assignment (emitMMAAnnotations):** After SMS assigns cycles, the pass derives pipeline stage annotations (`tt.autows`) for MMA ops using transitive MMA dependency counting:

- 0-1 transitive MMA predecessors → stage 0 (can be prefetched)
- 2+ transitive MMA predecessors → stage 1 (gated on multiple prior results)

Within each stage, independent MMAs share the same order (cluster ID) to avoid barrier deadlocks.

Example (FA backward, 5 MMAs):

| MMA | Transitive MMA deps | Stage | Order |
|-----|---------------------|-------|-------|
| qkT = dot(k, qT) | 0 | 0 | 0 |
| dpT = dot(v, do^T) | 0 | 0 | 0 |
| dv += dot(ppT, do) | 1 (qkT) | 0 | 1 |
| dq = dot(dsT^T, k) | 2 (qkT, dpT) | 1 | 0 |
| dk += dot(dsT, qT) | 2 (qkT, dpT) | 1 | 0 |

This matches the hand-tuned annotation partition exactly. Annotations are skipped when all MMAs land in the same stage (e.g., GEMM, FA forward) or when the loop already has `tt.autows` from Python `attrs=`.
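The counting rule can be sketched as follows (illustrative names only; the actual logic lives in `emitMMAAnnotations`):

```python
def transitive_mma_preds(mma, direct_deps):
    """Full transitive set of MMA ops whose results feed `mma`.
    direct_deps: maps each MMA to the MMAs it directly consumes."""
    seen = set()
    stack = list(direct_deps.get(mma, ()))
    while stack:
        p = stack.pop()
        if p not in seen:
            seen.add(p)
            stack.extend(direct_deps.get(p, ()))
    return seen


def assign_stage(mma, direct_deps):
    # 0-1 transitive MMA predecessors -> stage 0 (prefetchable);
    # 2+ -> stage 1 (gated on multiple prior results).
    return 0 if len(transitive_mma_preds(mma, direct_deps)) <= 1 else 1


# Direct MMA-to-MMA dependencies from the FA backward table above.
deps = {"dv": {"qkT"}, "dq": {"qkT", "dpT"}, "dk": {"qkT", "dpT"}}
stages = {m: assign_stage(m, deps) for m in ["qkT", "dpT", "dv", "dq", "dk"]}
# stages == {"qkT": 0, "dpT": 0, "dv": 0, "dq": 1, "dk": 1}
```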
FA BWD performance (B200, `TRITON_USE_META_WS=1 TRITON_USE_META_PARTITION=1`):

| Shape | Baseline TFLOPS | SMS TFLOPS | Diff |
|---|---|---|---|
| Z=4 H=16 N=2048 D=128 | 409.4 | 409.9 | +0.1% |
| Z=8 H=16 N=1024 D=128 | 324.7 | 323.3 | -0.4% |
| Z=1 H=32 N=4096 D=128 | 471.2 | 472.0 | +0.2% |
### Step 2.5: Compute Cluster IDs from the Modulo Schedule

After the modulo schedule assigns each op a `(cycle, pipeline)`, compute **cluster IDs** that encode within-stage instruction ordering for the downstream code generator.

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_DUMP_TLX_BENCHMARK",
     "TRITON_ENABLE_EXPERIMENTAL_CONSAN",
     "TRITON_PASS_PLUGIN_PATH",
-    "TRITON_STRICT_REDUCTION_ORDERING"
+    "TRITON_STRICT_REDUCTION_ORDERING",
+    "TRITON_USE_MODULO_SCHEDULE"
     // clang-format on
 };

python/triton/knobs.py

Lines changed: 1 addition & 1 deletion
@@ -511,7 +511,7 @@ class nvidia_knobs(base_knobs):
     libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
     use_meta_ws: env_bool = env_bool("TRITON_USE_META_WS")
     use_meta_partition: env_bool = env_bool("TRITON_USE_META_PARTITION")
-    use_modulo_schedule: env_bool = env_bool("TRITON_USE_MODULO_SCHEDULE")
+    use_modulo_schedule: env_opt_str = env_opt_str("TRITON_USE_MODULO_SCHEDULE")
     # Force OAI SWP schedule even when using Meta's WS implementation.
     force_trunk_swp_schedule: env_bool = env_bool("TRITON_FORCE_TRUNK_SWP_SCHEDULE")
     dump_ttgir_to_tlx: env_bool = env_bool("TRITON_DUMP_TTGIR_TO_TLX")

test/TritonGPU/modulo-schedule-graph-edge.mlir

Lines changed: 4 additions & 7 deletions
@@ -3,9 +3,8 @@

 //===----------------------------------------------------------------------===//
 // Edge case 0: Single-stage schedule (maxStage=0).
-// MMA-only loop: no TMA copy, no result use. The MMA self-latency (900) is
-// the only thing on the TC pipeline, so II = 900 and the MMA lands at
-// cycle 0, stage 0 — max_stage = 0.
+// MMA-only loop: no TMA copy, no result use. With selfLatency=1,
+// II = 1 (single TC op) and the MMA lands at cycle 0, stage 0.
 //
 // Regression test for Devmate review: tt.num_stages must be set even when
 // maxStage = 0 so downstream pipelining recognises the loop as scheduled.
@@ -18,11 +17,9 @@
 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {

 // Verify the maxStage=0 dump and the loop's tt.num_stages=1 attribute.
-// CHECK: ii = 900, max_stage = 0
+// CHECK: ii = 1, max_stage = 0
 // CHECK: @maxstage_0_mma_only
-// CHECK: tt.modulo_ii = 900 : i32
-// CHECK-SAME: tt.num_stages = 1 : i32
-// CHECK-SAME: tt.scheduled_max_stage = 0 : i32
+// CHECK: tt.num_stages = 1 : i32
 tt.func @maxstage_0_mma_only(
     %a: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>,
     %b: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>,

test/TritonGPU/modulo-schedule-graph.mlir

Lines changed: 14 additions & 14 deletions
@@ -13,37 +13,37 @@

 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {

-// --- Graph structure: II=1038, max_stage=2, trip_count=32 ---
+// --- Graph structure: II=1005, max_stage=1, trip_count=32 ---
+// With selfLatency=1, loads issue every cycle (not every 518 cycles),
+// so II is driven by RecMII (loop-carried dep: MMA→tmem_load→tmem_alloc→MMA).
 // CHECK: [PASS-A] === Inner Loop ScheduleGraph ===
 // CHECK-NEXT: modulo.schedule @loop0 {
-// CHECK-NEXT: ii = 1038, max_stage = 2, prologue_latency = 1038, trip_count = 32
+// CHECK-NEXT: ii = 1005, max_stage = 1, prologue_latency = 703, trip_count = 32
 //
-// --- Nodes: loads+allocs@s0, MMA@s1, tmem_load@s2 with cluster IDs ---
+// --- Nodes: loads+allocs+MMA@s0, tmem_load@s1 ---
 // CHECK: modulo.stage @s0 {
-// CHECK: tt.descriptor_load {pipe: MEM, cycle: 0, cluster: 0, latency: 1218, selfLatency: 518}
-// CHECK: tt.descriptor_load {pipe: MEM, cycle: 518, cluster: 1, latency: 1218, selfLatency: 518}
-// CHECK: ttg.local_alloc {pipe: MEM, cycle: 1036, cluster: 2, latency: 700
-// CHECK: ttg.local_alloc {pipe: MEM, cycle: 1037, cluster: 3, latency: 700
+// CHECK: tt.descriptor_load {pipe: MEM, cycle: 0, cluster: 0, latency: 1218, selfLatency: 1}
+// CHECK: tt.descriptor_load {pipe: MEM, cycle: 1, cluster: 1, latency: 1218, selfLatency: 1}
+// CHECK: ttg.local_alloc {pipe: MEM, cycle: 2, cluster: 2, latency: 700
+// CHECK: ttg.local_alloc {pipe: MEM, cycle: 3, cluster: 3, latency: 700
+// CHECK: ttng.tc_gen5_mma {pipe: TC, cycle: 703, cluster: 4, latency: 900, selfLatency: 1
 // CHECK: }
 // CHECK: modulo.stage @s1 {
-// CHECK: ttng.tc_gen5_mma {pipe: TC, cycle: 1737, cluster: 0, latency: 900, selfLatency: 900
-// CHECK: }
-// CHECK: modulo.stage @s2 {
-// CHECK: ttng.tmem_load {pipe: CUDA, cycle: 2637, cluster: 0, latency: 130, selfLatency: 130
+// CHECK: ttng.tmem_load {pipe: CUDA, cycle: 1603, cluster: 0, latency: 105, selfLatency: 1
 // CHECK: }
 //
 // --- Edges: SSA + loop-carried ---
 // CHECK: edges {
 // CHECK-DAG: N0 -> N1 lat=0 dist=0
 // CHECK-DAG: N0 -> N2 lat=0 dist=0
-// CHECK-DAG: N1 -> N3 lat=518 dist=0
-// CHECK-DAG: N2 -> N4 lat=518 dist=0
+// CHECK-DAG: N1 -> N3 lat=1 dist=0
+// CHECK-DAG: N2 -> N4 lat=1 dist=0
 // CHECK-DAG: N3 -> N6 lat=700 dist=0
 // CHECK-DAG: N4 -> N6 lat=700 dist=0
 // CHECK-DAG: N5 -> N6 lat=0 dist=0
 // CHECK-DAG: N5 -> N7 lat=0 dist=0
 // CHECK-DAG: N6 -> N7 lat=900 dist=0
-// CHECK-DAG: N7 -> N5 lat=130 dist=1
+// CHECK-DAG: N7 -> N5 lat=105 dist=1
 // CHECK: }
 // CHECK: }
 tt.func @test_basic_graph(

test/TritonGPU/modulo-schedule.mlir

Lines changed: 9 additions & 18 deletions
@@ -8,26 +8,17 @@

 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {

-// Verify that the modulo schedule pass annotates ops with loop.stage/loop.cluster
-// and sets tt.modulo_ii on the loop.
+// Verify that the modulo schedule pass sets tt.num_stages on the inner loop.
+// For a single-MMA GEMM, all MMAs are in the same stage so tt.autows is
+// skipped, and inner loops no longer emit loop.stage/loop.cluster attrs
+// (those are only emitted on outer loops via emitScheduleAttributes).
 //
 // CHECK-LABEL: @gemm_inner_loop
-// Cluster IDs are dense ranks of modulo cycles within each stage (Step 2.5).
-// Stages processed in reverse order: higher stage -> lower cluster ID.
-// Same cycle -> same cluster; different cycle -> different cluster.
-// CHECK: tt.descriptor_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 0 : i32}
-// CHECK: tt.descriptor_load {{.*}} {loop.cluster = 1 : i32, loop.stage = 0 : i32}
-// CHECK: ttg.local_alloc {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32}
-// CHECK: ttg.local_alloc {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
-// CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 0 : i32, loop.stage = 1 : i32}
-// CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
-// tt.num_stages = max_stage + 1 (set so downstream pipelining recognises
-// the loop as scheduled, even for single-stage modulo schedules).
-// tt.num_buffers attrs on local_allocs are added by the next stack diff
-// (Phase 1 buffer allocation on ScheduleGraph).
-// CHECK: tt.modulo_ii = 1038 : i32
-// CHECK-SAME: tt.num_stages = 3 : i32
-// CHECK-SAME: tt.scheduled_max_stage = 2 : i32
+// CHECK: scf.for
+// CHECK-NOT: loop.stage
+// CHECK-NOT: loop.cluster
+// CHECK-NOT: tt.autows
+// CHECK: tt.num_stages = 2 : i32
 tt.func @gemm_inner_loop(
     %a_desc: !tt.tensordesc<tensor<128x64xf16>>,
     %b_desc: !tt.tensordesc<tensor<64x128xf16>>

test/TritonGPU/modulo-ws-partition.mlir

Lines changed: 10 additions & 13 deletions
@@ -8,21 +8,18 @@

 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {

-// Verify that Pass B assigns utilization-driven ttg.partition attrs on a
-// persistent kernel with a WS outer loop containing an inner K-loop.
-// Expected partitions: MEM=0, TC=1, CUDA(tmem_load)=2.
-// Shared/scalar ops get allParts [0,1,2].
+// Verify that the modulo schedule pass runs on the inner loop and the
+// ws-partition pass processes the outer WS loop. With selfLatency=1, the
+// single-MMA GEMM inner loop gets tt.num_stages=2 and no tt.autows
+// (all MMAs in same stage). The outer loop gets tt.warp_specialize.
 //
 // CHECK-LABEL: @persistent_gemm_ws_partition
-// MEM ops (descriptor_load, local_alloc) → partition 0
-// CHECK: tt.descriptor_load {{.*}} ttg.partition = array<i32: 0>
-// CHECK: tt.descriptor_load {{.*}} ttg.partition = array<i32: 0>
-// CHECK: ttg.local_alloc {{.*}} ttg.partition = array<i32: 0>
-// CHECK: ttg.local_alloc {{.*}} ttg.partition = array<i32: 0>
-// TC ops (tc_gen5_mma) → partition 1
-// CHECK: ttng.tc_gen5_mma {{.*}} ttg.partition = array<i32: 1>
-// CUDA ops (tmem_load) → partition 2
-// CHECK: ttng.tmem_load {{.*}} ttg.partition = array<i32: 2>
+// CHECK: scf.for
+// Inner loop has tt.num_stages from modulo schedule
+// CHECK: scf.for
+// CHECK: tt.num_stages = 2 : i32
+// Outer loop has tt.warp_specialize
+// CHECK: tt.warp_specialize
 tt.func @persistent_gemm_ws_partition(
     %a_desc: !tt.tensordesc<tensor<128x64xf16>>,
     %b_desc: !tt.tensordesc<tensor<64x128xf16>>,

third_party/nvidia/backend/compiler.py

Lines changed: 17 additions & 5 deletions
@@ -400,12 +400,24 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.ttgpuir.add_optimize_accumulator_init(pm)
         passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
         nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
-        nvidia.passes.hopper.add_data_partitioning(pm, 1)
-        if knobs.nvidia.use_modulo_schedule:
+        if knobs.nvidia.use_modulo_schedule is not None:
+            # Modulo schedule runs BEFORE data partitioning so it can
+            # see MMA ops before they're moved into WS regions. It
+            # sets tt.autows annotations (stage/order) on MMA ops.
+            # TRITON_USE_MODULO_SCHEDULE=1 (default algo: rau)
+            # TRITON_USE_MODULO_SCHEDULE=sms|exhaustive|random
             nvidia.passes.hopper.add_modulo_schedule(pm)
-        else:
-            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages, use_meta_swp_schedule)
-            passes.ttgpuir.add_schedule_loops(pm, opt.num_stages, use_meta_swp_schedule)
+        nvidia.passes.hopper.add_data_partitioning(pm, 1)
+        # assign_latencies sets tt.latency on loads/MMAs (stage-distance
+        # latencies). schedule_loops reads tt.latency AND tt.autows:
+        # when MMA ops have tt.autows, scheduleKeyOpsAnnotation places
+        # them at the annotated stages/clusters while scheduling all
+        # other ops (loads, softmax, barriers) via the standard
+        # latency-based heuristic. Without assign_latencies, the WS
+        # pass's internal scheduleLoops has no latencies and can't
+        # enter the code path that reads tt.autows annotations.
+        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages, use_meta_swp_schedule)
+        passes.ttgpuir.add_schedule_loops(pm, opt.num_stages, use_meta_swp_schedule)
         if not knobs.nvidia.use_meta_ws:
             passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
         else:

third_party/nvidia/hopper/lib/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ add_triton_library(NVHopperTransforms
   ModuloScheduling/LatencyModel.cpp
   ModuloScheduling/DataDependenceGraph.cpp
   ModuloScheduling/ModuloReservationTable.cpp
+  ModuloScheduling/SwingScheduler.cpp
+  ModuloScheduling/ExhaustiveScheduler.cpp
   ModuloScheduling/ModuloSchedulePass.cpp
   ModuloScheduling/ModuloWSPartitionPass.cpp
   ModuloScheduling/ModuloScheduleGraph.cpp
