Lightweight herd cloning during shim DMA BD loop unrolling #1535
erwei-xilinx wants to merge 2 commits into
Pull request overview
This PR optimizes the air-opt-shim-dma-bds pass to avoid IR blow-up by skipping runtime-loop tiling and full unrolling when shim-dma-tile-sizes is empty or all-1 (the default path), while preserving scf.for loops for later/lower-cost unrolling in downstream passes.
Changes:
- Add a fast path in `AIROptimizeShimDMABDs` to skip tiling/unrolling when tile sizes are empty or all-1, while still performing L3 DMA folding.
- Add a new MLIR regression test covering all-1, non-trivial tiling (2), and empty-tile-size behavior.
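The fast-path condition described above (tile sizes empty or all-1) can be sketched as a small predicate. This is a minimal illustration only: `isTrivialTiling` is a hypothetical name, and the tile sizes are modeled as a plain `std::vector<unsigned>` rather than whatever option type the pass actually uses.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper (not taken from the PR): returns true when the
// requested tile sizes would make tiling a no-op, i.e. the list is empty
// or every entry equals 1.
bool isTrivialTiling(const std::vector<unsigned> &tileSizes) {
  // std::all_of returns true for an empty range, which covers the
  // "no tile sizes given" default case as well.
  return std::all_of(tileSizes.begin(), tileSizes.end(),
                     [](unsigned s) { return s == 1; });
}
```

A caller would evaluate this once per pass invocation and branch to the fast path when it returns true.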
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `mlir/lib/Transform/AIRDependencyScheduleOpt.cpp` | Adds fast-path logic to skip tiling/unrolling for empty/all-1 tile sizes and attempts to preserve downstream barrier behavior. |
| `mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds_skip_unroll.mlir` | New test validating skip-unroll behavior and non-trivial tiling behavior. |
Closing this PR. Hardware E2E testing on local NPU1 confirmed the change produces incorrect DMA BD configurations.
Root cause: Skipping BD folding (
What doesn't work: Simply skipping the tiling/unrolling/BD-folding for all-1 tile sizes. The inner BD folding patterns (
What would work: A more surgical approach that runs BD folding on the channel ops but prevents
The compilation speed optimization plan at
Closing this PR. The compilation speed optimization (5x on MLIR passes for flash attention) is real and validated by profiling, but all implementation approaches explored so far hit correctness issues:
The correct approach requires either:
The profiling analysis and detailed plan remain at
Known limitation: tiled path not covered
The lightweight unroller currently only fires when
For the default
Extending lightweight cloning to the tiled
For shim-level scf.for loops that contain air.SegmentOp / air.HerdOp, replace the standard loopUnrollFull (which deep-clones the entire body N times) with a manual unroller that uses lightweight herd cloning.

The lightweight herd clone (cloneHerdOpLightweight) creates a new herd shell via OperationState and populates the body with ONLY channel ops, allocs, deallocations, wait_alls, and their transitive operand-defining ops. Heavy compute ops (matrix multiply, vector ops, etc.) are never cloned.

The segment clone (cloneSegmentOpLightweight) preserves the full segment body (L3 channel ops needed by BD folding) but applies lightweight cloning recursively to any contained air.HerdOp.

This brings flash attention 12x4 (tiles=2,2) from crashing (T002) or taking O(N*body_size) time back to ~50ms, matching the main-branch baseline.

Also adds a lit test (func_lightweight_unroll) that verifies channel ops are preserved and compute ops are excluded from the lightweight herd copies.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
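To make the cost argument in this commit message concrete, here is a minimal sketch on a toy op list. It is not the MLIR implementation: `ToyOp`, `Kind`, `cloneLightweight`, and `unrolledOpCount` are hypothetical names, and the sketch filters only by op kind, leaving out the transitive operand-defining ops that the real clone also keeps.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an op inside a herd body; not the MLIR Operation class.
enum class Kind { Channel, Alloc, Dealloc, WaitAll, Compute };

struct ToyOp {
  Kind kind;
  std::string name;
};

// Assumption for this sketch: everything except compute ops is kept.
bool keepInLightweightClone(const ToyOp &op) {
  return op.kind != Kind::Compute;
}

// Copy only the ops that survive the filter, preserving their order.
std::vector<ToyOp> cloneLightweight(const std::vector<ToyOp> &body) {
  std::vector<ToyOp> out;
  for (const ToyOp &op : body)
    if (keepInLightweightClone(op))
      out.push_back(op);
  return out;
}

// Number of ops created by fully unrolling a loop of tripCount iterations,
// either with a deep clone of the body or with the lightweight clone.
std::size_t unrolledOpCount(const std::vector<ToyOp> &body,
                            std::size_t tripCount, bool lightweight) {
  std::size_t perIteration =
      lightweight ? cloneLightweight(body).size() : body.size();
  return tripCount * perIteration;
}
```

With a body dominated by compute ops, the deep clone grows as N × body_size while the lightweight clone grows only with the number of kept ops per iteration.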
…rminating walk

Two efficiency improvements identified in self-review:

1. collectHerdBodyOpsToKeep: replace fixed-point iteration (copies toKeep set each round) with a worklist algorithm. Each newly-added op is pushed once and processed once, avoiding redundant re-scanning of already-kept ops in subsequent rounds.

2. loopUnrollFullWithAsyncTokenPreserved: replace two sequential walks (one for SegmentOp, one for HerdOp) with a single interruptible walk that stops as soon as the first SegmentOp or HerdOp is found. Avoids walking the entire IR a second time in the common case where a segment is present.

No functional change; build and all tests pass (same 2 pre-existing ROCDL failures unrelated to T003).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
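The worklist idea in the first item can be sketched on a toy dependency graph. This is an illustration under stated assumptions, not the code in `collectHerdBodyOpsToKeep`: `DepOp` and `collectOpsToKeep` are hypothetical names, ops are identified by their index in the body, and `isSeed` stands in for "channel op, alloc, dealloc, or wait_all".

```cpp
#include <cstddef>
#include <unordered_set>
#include <vector>

// Toy op: a flag marking it as a seed, plus the indices of the ops that
// define its operands. Not the MLIR Operation class.
struct DepOp {
  bool isSeed;
  std::vector<std::size_t> operands;
};

// Returns the indices of all seed ops and, transitively, of every op that
// defines one of their operands.
std::unordered_set<std::size_t>
collectOpsToKeep(const std::vector<DepOp> &body) {
  std::unordered_set<std::size_t> keep;
  std::vector<std::size_t> worklist;

  // Seed the worklist with the ops that must always be kept.
  for (std::size_t i = 0; i < body.size(); ++i)
    if (body[i].isSeed && keep.insert(i).second)
      worklist.push_back(i);

  // Each op enters the worklist at most once, because it is pushed only
  // when its insertion into `keep` succeeds.
  while (!worklist.empty()) {
    std::size_t current = worklist.back();
    worklist.pop_back();
    for (std::size_t def : body[current].operands)
      if (keep.insert(def).second)
        worklist.push_back(def);
  }
  return keep;
}
```

Compared with re-scanning the whole kept set until nothing changes, this visits each kept op and each of its operand edges once.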
Summary
Avoid O(N × body_size) IR explosion during shim-level loop unrolling in `air-opt-shim-dma-bds` by skipping deep-clone of herd compute bodies.

When `loopUnrollFullWithAsyncTokenPreserved` detects a shim-level loop (containing `air.SegmentOp` or `air.HerdOp`), it uses a custom manual unroller that:
- uses `OperationState` to create empty herd shells, avoiding custom builder API issues

Profiling (flash attention 12×4 launch, tiles=2,2)
`air-opt-shim-dma-bds`

Current limitation
The lightweight unroller only fires when `annotateFn` is null (the non-tiled unroll path in `loopUnrollFullWithAsyncTokenPreserved`). The tiled path, which uses `loopUnrollByFactor` with an `annotateFn` callback, still performs full deep-clone unrolling. Extending lightweight cloning to the tiled path is left as a follow-up.

Test plan
- `ninja check-air-mlir` passes (365/374, only pre-existing failures)
- `LIGHTWEIGHT` FileCheck test in `opt_shim_dma_bds.mlir` verifies channel ops preserved, compute ops stripped