The algorithm as described has several limitations:
7. **Register allocation is approximate**: Pass B Step 4 estimates register usage from live variable counts but doesn't perform full register allocation. The actual register count is determined by the compiler backend (ptxas), which may differ from the estimate and cause spills that the schedule didn't anticipate.

8. **SMS limitations**: The SMS implementation's simplified ASAP/ALAP computation (no II-dependent recurrence bounds) and BFS ordering (no SCC prioritization) may produce suboptimal schedules for kernels with multiple interacting recurrence circuits, such as FA backward with 5 MMA ops and cross-iteration accumulator/softmax/pointer dependencies. For single-MMA kernels (GEMM), SMS and Rau produce identical schedules.
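As a rough illustration of the kind of estimate item 7 describes, register demand can be approximated by the peak number of simultaneously live values. The helper below is a hypothetical sketch, not the pass's actual code; the real count still comes from ptxas.

```python
def max_live(intervals):
    """Peak number of simultaneously live values.

    intervals: (def_cycle, last_use_cycle) pairs, one per value.
    A rough proxy for register demand; the backend (ptxas) may
    allocate differently and spill where this estimate did not predict.
    """
    events = []
    for start, end in intervals:
        events.append((start, 1))       # value becomes live
        events.append((end + 1, -1))    # value dies after last use
    live = peak = 0
    for _, delta in sorted(events):
        live += delta
        peak = max(peak, live)
    return peak
```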
Swing Modulo Scheduling (SMS; J. Llosa, A. Gonzalez, E. Ayguade, M. Valero, "Swing Modulo Scheduling: A Lifetime-Sensitive Approach", PACT 1996) avoids backtracking by using a slack-based node ordering and directional placement.
**Key differences from Rau's IMS:**

| Property | Rau's IMS | SMS |
|----------|-----------|-----|
| Complexity | Potentially exponential (backtracking) | O(n) per II attempt |
| Placement | Earliest free slot, eject if blocked | Top-down for successors, bottom-up for predecessors |
| Register pressure | Not considered | Reduced by keeping producer-consumer pairs close |
**SMS Algorithm:**

1. **Compute ASAP/ALAP**: Forward/backward relaxation including loop-carried edges (II-dependent: `ASAP[v] >= ASAP[u] + latency - distance * II`), recomputed for each candidate II. Slack = ALAP - ASAP measures scheduling freedom.

2. **Ordering phase (swing)**: Start with the minimum-slack op (most constrained). Then BFS-expand: add its successors (marked top-down) sorted by ascending slack, then its predecessors (marked bottom-up) sorted by ascending slack. This alternation is the "swing" — it keeps producers and consumers adjacent in the schedule.

3. **Scheduling phase**: For each op in swing order:
   - **Top-down** ops: place at the earliest free slot from `earliest` upward (data is ready, issue immediately).
   - **Bottom-up** ops: place at the latest free slot from `latest` downward (defer production, reducing live range and register pressure).
```python
def sms_schedule(DDG, latencies, unit_map, MinII):
    for II in range(MinII, MinII + 11):  # capped at MinII+10
        # Recompute per-II: loop-carried edges depend on II
        asap = compute_ASAP(DDG, latencies, II)
        alap = compute_ALAP(DDG, latencies, asap, II)
        slack = {op: alap[op] - asap[op] for op in DDG.nodes}

        table = ReservationTable(II)
        scheduled = {}

        # Ordering: BFS from min-slack seed
        seed = min(DDG.nodes, key=lambda n: slack[n])
        order = [(seed, True)]  # (node, is_top_down)
        visited = {seed}
        for node, _ in order:  # BFS: the list grows while iterating
            # Successors → top-down
            for s in sorted(successors(node), key=lambda n: slack[n]):
                if s not in visited:
                    order.append((s, True))
                    visited.add(s)
            # Predecessors → bottom-up
            for p in sorted(predecessors(node), key=lambda n: slack[n]):
                if p not in visited:
                    order.append((p, False))
                    visited.add(p)

        # ... scheduling phase (directional placement) follows ...
```
**Implementation status:** SMS is available via `TRITON_USE_MODULO_SCHEDULE=sms`. Source: `SwingScheduler.cpp`. The implementation has the following simplifications relative to the paper:

1. **No recurrence-aware ordering.** The paper identifies SCCs, orders them by RecMII contribution, and schedules the most critical recurrence first. The implementation uses simple BFS from the minimum-slack node.

2. **Fallback on placement failure.** When the directional scan finds no free slot, the implementation falls back to `find_free` from earliest. The paper would fail at this II and increment.

3. **BFS follows all DDG edges**, including loop-carried (distance > 0). The paper's ordering only follows distance-0 edges.
ASAP/ALAP include loop-carried edges and are recomputed per-II: `ASAP[v] >= ASAP[u] + latency - distance * II`, with a convergence limit of 1000 iterations.
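This fixed-point relaxation can be sketched as below, under the same constraint `ASAP[v] >= ASAP[u] + latency - distance * II`; the function name and edge encoding are assumptions for illustration, not the pass's actual symbols.

```python
def compute_asap(nodes, edges, latencies, II, max_iters=1000):
    """Fixed-point ASAP including loop-carried edges.

    edges: (u, v, distance) triples; a loop-carried edge (distance > 0)
    relaxes v against u minus distance * II, which is why ASAP must be
    recomputed for each candidate II. Iteration is capped (the source
    uses a convergence limit of 1000), since II below RecMII diverges.
    """
    asap = {n: 0 for n in nodes}
    for _ in range(max_iters):
        changed = False
        for u, v, dist in edges:
            bound = asap[u] + latencies[u] - dist * II
            if bound > asap[v]:
                asap[v] = bound
                changed = True
        if not changed:
            break
    return asap
```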
**selfLatency model:** All pipelines use `selfLatency = 1` because GPU execution units are deeply pipelined — a new instruction can be issued every ~1 cycle. This makes ResMII negligible (equal to the op count on the busiest pipeline) and lets RecMII (data dependencies) drive the schedule. Without this fix, SMS fails on FA backward (ResMII=4500 from 5 MMAs × 900 selfLatency each).
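The ResMII arithmetic behind this can be sketched as follows (hypothetical helper; `unit_map` assigns each op to a pipeline, as in the pseudocode above):

```python
from collections import Counter


def res_mii(ops, unit_map, self_latency=1):
    """ResMII for fully pipelined units.

    Each op occupies its pipeline for `self_latency` cycles before the
    next issue, so ResMII is the busiest pipeline's op count times
    self_latency. With self_latency=1 this reduces to the op count on
    the busiest pipeline, letting RecMII drive the schedule.
    """
    counts = Counter(unit_map[op] for op in ops)
    return max(counts.values()) * self_latency
```

With 5 MMAs on one pipeline, `self_latency=900` gives the ResMII=4500 blow-up mentioned above, while `self_latency=1` gives 5.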
**Stage assignment (emitMMAAnnotations):** After SMS assigns cycles, the pass derives pipeline stage annotations (`tt.autows`) for MMA ops by counting each MMA's transitive MMA dependencies. Within each stage, independent MMAs share the same order (cluster ID) to avoid barrier deadlocks.
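The counting step can be sketched as below. The direct MMA-to-MMA edges used in the test are an assumption reconstructed from the FA backward example table; the mapping from counts to stage/order is not reproduced here.

```python
def transitive_mma_deps(mma_preds):
    """Count each MMA's transitive MMA dependencies.

    mma_preds maps an MMA to the set of MMAs it directly consumes
    (possibly through intervening non-MMA ops). Illustrative sketch
    of the dependency counting only.
    """
    memo = {}

    def closure(m):
        if m not in memo:
            memo[m] = set()  # placeholder breaks loop-carried cycles
            acc = set()
            for p in mma_preds.get(m, ()):
                acc.add(p)
                acc |= closure(p)
            memo[m] = acc
        return memo[m]

    return {m: len(closure(m)) for m in mma_preds}
```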
Example (FA backward, 5 MMAs):

| MMA | Transitive MMA deps | Stage | Order |
|-----|---------------------|-------|-------|
| qkT = dot(k, qT) | 0 | 0 | 0 |
| dpT = dot(v, do^T) | 0 | 0 | 0 |
| dv += dot(ppT, do) | 1 (qkT) | 0 | 1 |
| dq = dot(dsT^T, k) | 2 (qkT, dpT) | 1 | 0 |
| dk += dot(dsT, qT) | 2 (qkT, dpT) | 1 | 0 |
This matches the hand-tuned annotation partition exactly. Annotations are skipped when all MMAs land in the same stage (e.g., GEMM, FA forward) or when the loop already has `tt.autows` from Python `attrs=`.
FA BWD performance (B200, `TRITON_USE_META_WS=1 TRITON_USE_META_PARTITION=1`):

| Shape | Baseline TFLOPS | SMS TFLOPS | Diff |
|---|---|---|---|
| Z=4 H=16 N=2048 D=128 | 409.4 | 409.9 | +0.1% |
| Z=8 H=16 N=1024 D=128 | 324.7 | 323.3 | -0.4% |
| Z=1 H=32 N=4096 D=128 | 471.2 | 472.0 | +0.2% |
### Step 2.5: Compute Cluster IDs from the Modulo Schedule
After the modulo schedule assigns each op a `(cycle, pipeline)`, compute **cluster IDs** that encode within-stage instruction ordering for the downstream code generator.