
[TLE] Optimize sparse MLA forward #534

Open

sunnycase wants to merge 6 commits into triton_v3.6.x from feature/tle_sparse_mla


@sunnycase sunnycase commented Apr 20, 2026

Summary

This PR builds on PR #489 (3.6 TLE migration for TopK/cumsum paths) and completes the sparse MLA forward work for the 3.6 branch. It includes the TLE sparse MLA tutorial/benchmark entry, WGMMA descriptor/fence support, TLE-owned staging and tile-style pipeline passes, and the native Triton hooks needed by those TLE paths under #ifdef __TLE__.

What’s Added

1. Sparse MLA forward tutorial and benchmark coverage

  • Extended python/tutorials/tle/deepseek_v32/02-sparse-mla.py with Triton, TLE, TileLang, TileLang-pipelined, TileLang-seesaw, and FlashMLA-compatible forward paths.
  • Added topk_length support so sparse MLA can skip invalid or shorter top-k regions instead of always iterating the full static topk.
  • Added FlashMLA-compatible prefill and decode input generation, plus benchmark entry points for both modes.
  • Removed the experimental TLE warp-specialized provider from the public benchmark list; this PR does not rely on Hopper WSpec for sparse MLA.
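To make the topk_length behavior concrete, here is a hedged, pure-NumPy reference sketch of a sparse MLA forward pass (not the tutorial's actual Triton/TLE kernel; the function name, shapes, and layout are illustrative assumptions). Each query attends only to the first topk_length entries of its top-k index row, so shorter or invalid top-k tails are skipped rather than iterated:

```python
import numpy as np

def sparse_mla_forward_ref(q, kv, indices, topk_length, dv):
    """Illustrative reference sparse MLA forward (pure NumPy, single batch).

    q:            [S, H, DQK]  queries
    kv:           [SKV, DQK]   shared latent KV cache (first dv dims act as V)
    indices:      [S, topk]    top-k KV positions per query
    topk_length:  [S]          number of valid entries in each top-k row
    dv:           value head dim (DV <= DQK)
    """
    S, H, DQK = q.shape
    out = np.zeros((S, H, dv), dtype=np.float32)
    for s in range(S):
        L = int(topk_length[s])           # skip the invalid/short top-k tail
        idx = indices[s, :L]
        k = kv[idx]                       # [L, DQK]
        v = kv[idx, :dv]                  # [L, dv]
        logits = q[s].astype(np.float32) @ k.T.astype(np.float32)  # [H, L]
        logits -= logits.max(axis=-1, keepdims=True)               # stable softmax
        p = np.exp(logits)
        p /= p.sum(axis=-1, keepdims=True)
        out[s] = p @ v.astype(np.float32)
    return out
```

Because only `indices[s, :topk_length[s]]` is ever read, garbage beyond the valid length cannot affect the output, which is the contract the kernel-side topk_length support relies on.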

2. WGMMA descriptor view and shared operand fencing

  • Added tle.memdesc_wgmma_view as a descriptor-only view for existing shared-memory tiles consumed by WGMMA.
  • Added tle.wgmma_shared_operand_fence to order generic-proxy shared writes before WGMMA async-proxy reads.
  • Added TLE-to-LLVM lowering for WGMMA descriptor views and shared operand fences.
  • Updated NVIDIA fence insertion to emit dependency-aware TLE operand fences when WGMMA consumes shared descriptors written by async-copy/local staging paths.
  • Updated dot operand optimization so eligible local_load(existing_smem) operands can be reused directly as WGMMA shared operands instead of forcing extra register/materialization paths.

3. TLE staging and tile-style pipeline passes

  • Added triton-tle-optimize-local-pointer-async-stores to rewrite eligible global load -> local pointer store staging into direct ttg.async_copy_global_to_local over memdesc subviews.
  • Added triton-tle-promote-local-store-staging to expose safe loop-local staging as pipelineable local_alloc patterns.
  • Added triton-tle-tile-style-pipeline-schedule and triton-tle-materialize-tile-style-pipeline for TileLang-style preload/use scheduling on selected TLE loops.
  • Preserved explicit async wait counts for TLE tile-style pipeline regions where the generic WGMMA wait rewrite would otherwise erase the intended schedule.
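The preload/use scheduling these passes implement can be sketched in plain Python as a classic software pipeline. This is only a conceptual model under stated assumptions: `load_tile`-style preloads stand in for ttg.async_copy_global_to_local issues, the buffer slots model the smem stages the pass allocates, and the compute step stands in for the WGMMA consumer; none of these names come from the actual pass implementation:

```python
def pipelined_loop(tiles, num_stages=2):
    """Process tiles with up to `num_stages` preloads in flight.

    Prologue peels the first `num_stages` loads; the steady-state body
    consumes stage i while issuing the preload for stage i + num_stages,
    mirroring a TileLang-style preload/use schedule.
    """
    buffers = [None] * num_stages
    results = []
    n = len(tiles)
    # prologue: preload the first num_stages tiles (models async copy issue)
    for i in range(min(num_stages, n)):
        buffers[i % num_stages] = tiles[i]
    # steady state: wait on + use stage i, then preload stage i + num_stages
    for i in range(n):
        tile = buffers[i % num_stages]
        results.append(tile * 2)               # stand-in for the compute/WGMMA step
        j = i + num_stages
        if j < n:
            buffers[j % num_stages] = tiles[j]
    return results
```

The point of preserving explicit async wait counts is visible here: the consumer at stage `i` is only correct if exactly the right number of outstanding copies has completed, so a generic wait rewrite that collapses the counts would break the intended overlap.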

4. Native Triton hooks guarded for TLE

  • Tightened MMAv3 chained-dot warp assignment under #ifdef __TLE__:
    • direct A/B chained-dot users keep one-axis warp assignment;
    • D/C-accumulator-only dot users no longer trigger the flash-attention chained-dot heuristic.
  • Added TLE-only short static pipeline-loop handling under #ifdef __TLE__, so static loops with trip_count <= num_stages can still expand with the needed peeled/predicated stages.
  • Kept non-TLE/native Triton behavior on the upstream code path for these hooks.
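The short static loop case can be illustrated with a hedged Python sketch of the schedule shape (the tuple encoding is invented for illustration and does not reflect the real pass's IR). When `trip_count <= num_stages` there is no steady state at all: every preload and use lands in peeled prologue/epilogue stages, and preloads past the trip count must be predicated off rather than dropped:

```python
def expand_short_static_loop(trip_count, num_stages):
    """Emit (op, index, predicated_off) entries for a short static pipelined loop.

    With trip_count <= num_stages, all iterations are fully peeled:
    one preload slot per stage (out-of-range ones predicated off),
    followed by one use per real iteration.
    """
    schedule = []
    # peeled prologue: one preload per stage; predicate off out-of-range slots
    for s in range(num_stages):
        schedule.append(("preload", s, s >= trip_count))
    # peeled epilogue: uses only for the real iterations
    for i in range(trip_count):
        schedule.append(("use", i, False))
    return schedule
```

This is why a static loop with, say, 2 iterations and 4 stages can still expand under the TLE-only hook: the extra two preload slots exist but are predicated, so no out-of-bounds copy is ever issued.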

5. Regression coverage

  • Added Python regression checks for sparse MLA autotune/benchmark contracts and no-chunk TLE staging structure.
  • Added TLE MLIR tests under third_party/tle/test/GPU/ for WGMMA descriptor views, WGMMA shared operand fences, local pointer async stores, local-store staging promotion, tile-style pipeline scheduling/materialization, chained-dot A/B-vs-C detection, short static loop pipelining, and async-wait preservation.

Performance

Environment

  • GPU: NVIDIA H800
  • Device: CUDA_VISIBLE_DEVICES=6
  • Date: 2026-04-20
  • Data type: BF16
  • Benchmark timing: 600 ms warmup, 1200 ms measurement
  • Input mode: FlashMLA-compatible non-causal prefill
  • All rows use the same provider set and the same generated input per shape.

Sparse MLA Forward Prefill

Configuration for all rows: B=1, S=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, topk_length=2048 for every query/KV group.

| SKV    | Triton ms | Triton TFLOP/s | TLE ms | TLE TFLOP/s | TLE speedup | TLE max diff |
|--------|-----------|----------------|--------|-------------|-------------|--------------|
| 8192   | 13.382    | 174.60         | 11.907 | 196.22      | 1.12x       | 0.000122     |
| 32768  | 16.812    | 138.97         | 15.193 | 153.79      | 1.11x       | 0.000122     |
| 65536  | 16.930    | 138.00         | 20.850 | 112.06      | 0.81x       | 0.000122     |
| 98304  | 18.266    | 127.91         | 21.571 | 108.31      | 0.85x       | 0.000122     |
| 131072 | 18.275    | 127.85         | 22.460 | 104.03      | 0.81x       | 0.000122     |
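For reference, the TFLOP/s column follows from the fixed row configuration under the usual attention FLOP convention (2·topk·DQK FLOPs for Q·Kᵀ plus 2·topk·DV for P·V per query-head, softmax ignored). This is a back-of-envelope sketch, not the benchmark script's actual accounting:

```python
def sparse_mla_tflops(ms, B=1, S=4096, H=128, topk=2048, dqk=576, dv=512):
    """Approximate achieved TFLOP/s for one prefill row above.

    Counts 2 * topk * (DQK + DV) FLOPs per query-head (QK^T plus PV matmuls);
    softmax and masking FLOPs are ignored, per the common attention convention.
    """
    flops = 2.0 * topk * (dqk + dv) * B * S * H
    return flops / (ms * 1e-3) / 1e12
```

Plugging in the Triton time for SKV=8192 (13.382 ms) reproduces the tabulated ~174.6 TFLOP/s, confirming the table's FLOP model is independent of SKV (only topk-selected keys are touched).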

Additional condition check:

  • The earlier small-case result still shows TLE ahead: B=1, S=1024, SKV=4096, topk=512 gives TLE 0.595 ms vs Triton 0.786 ms (1.32x).
  • The PR-size prefill workload is much larger: S=4096, topk=2048, and full topk_length=2048; it should not be compared directly against the earlier small-case result.
  • A causal input generator is a different workload because topk_length varies by query position; for SKV=32768, the measured average topk_length was 1535.8, not 2048.

Correctness checks passed in the local prefill benchmark runs for TLE against the Triton sparse MLA output.


Validation

  • Backend and Python extension rebuilt successfully from this branch.
  • TLE MLIR regression tests passed for chained-dot warp assignment and short static loop pipelining.
  • Python sparse MLA regression tests passed: 5 passed in 3.79s.
  • Local sparse MLA prefill benchmark correctness checks passed against Triton outputs.

