
perf(dyt): enable triton.autotune for DyT forward and backward kernels #1208

Open
shivam2199 wants to merge 2 commits into linkedin:main from shivam2199:perf/dyt-autotune

Conversation

@shivam2199
Contributor

Summary

Enables @triton.autotune on the DyT (Dynamic Tanh) forward and backward kernels. The original author had drafted autotune configs but left them commented out, with a static BLOCK_N heuristic in the Python wrappers (liger_dyt_fwd / liger_dyt_bwd) serving as a placeholder.
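For context, DyT (from "Transformers without Normalization") replaces a norm layer with an elementwise gamma * tanh(alpha * x) plus an optional bias, where alpha is a learnable scalar. A plain-PyTorch sketch of those semantics (the function name and signature here are illustrative, not the repo's API):

```python
import torch

def dyt_reference(x, alpha, gamma, beta=None):
    # DyT: elementwise tanh scaled by a learnable scalar alpha, followed by a
    # per-channel affine. Purely memory-bound, which is why the pruned config
    # space below favors shallow pipelines (num_stages <= 2).
    y = gamma * torch.tanh(alpha * x)
    return y + beta if beta is not None else y
```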

Changes

  1. Enable @triton.autotune on both _dyt_fwd_kernel and _dyt_bwd_kernel with a pruned config space (18 configs each; see the decorator sketch after this list):
    • BLOCK_N in {1024, 2048, 4096}
    • num_stages in {1, 2} (dropped num_stages=4 from the original stub — rarely wins for memory-bound elementwise kernels)
    • num_warps in {4, 8, 16} (dropped num_warps=32 from the fwd stub — overkill for this kernel)
  2. Remove the manual BLOCK_N heuristic (if N >= 4096:) and the manual kwargs in both Python wrappers — autotune handles selection.
  3. Add reset_to_zero=["DA"] to the backward autotune decorator. The DA reduction buffer is indexed by program_id(0), so different BLOCK_N configs write to different slot counts per SM. Autotune trials don't zero output buffers between runs, so stale slots from a prior config would leak into the final da.sum() of the winning config. Same pattern as used in fused_moe_kernels.py.
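For concreteness, the backward decorator could look roughly like this. The 18-config grid and reset_to_zero=["DA"] are from this PR; the kernel's argument list and the autotune key are abridged/assumed here:

```python
import itertools

import triton
import triton.language as tl

# 3 BLOCK_N values x 2 num_stages x 3 num_warps = 18 configs.
_DYT_CONFIGS = [
    triton.Config({"BLOCK_N": block_n}, num_stages=stages, num_warps=warps)
    for block_n, stages, warps in itertools.product(
        (1024, 2048, 4096), (1, 2), (4, 8, 16)
    )
]

@triton.autotune(
    configs=_DYT_CONFIGS,
    key=["N"],             # assumed key: re-tune once per unique hidden size
    reset_to_zero=["DA"],  # zero the DA partial-sum buffer before each trial
)
@triton.jit
def _dyt_bwd_kernel(DY, DX, DA, X, Alpha, M, N, BLOCK_N: tl.constexpr):
    pass  # body unchanged by the PR; BLOCK_N now comes from the winning config
```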

Performance (NVIDIA A10G, sm_86, bf16, BT=2048)

Measured via benchmark/scripts/benchmark_dyt.py --sweep-mode model_config across 7 model hidden sizes (2048 -> 8192), for both beta=True and beta=False, in each of fwd/bwd/full mode. Each figure is the median over do_bench trials.
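A median timing like those below can be reproduced along these lines with triton.testing.do_bench (the import path and call signature of liger_dyt_fwd are assumptions here; the actual harness is benchmark/scripts/benchmark_dyt.py):

```python
import torch
from triton.testing import do_bench

# Assumed import path for the wrapper named in this PR.
from liger_kernel.ops.dyt import liger_dyt_fwd

BT, N = 2048, 4096  # llama_2_7b-sized hidden dim from the sweep
x = torch.randn(BT, N, device="cuda", dtype=torch.bfloat16)
alpha = torch.ones(1, device="cuda", dtype=torch.bfloat16)
gamma = torch.ones(N, device="cuda", dtype=torch.bfloat16)

# do_bench warms up first (absorbing the one-time autotune sweep), then
# returns the 0.5 quantile (median) of the timed runs in milliseconds.
median_ms = do_bench(
    lambda: liger_dyt_fwd(x, alpha, gamma),  # signature assumed
    quantiles=[0.5],
)
print(f"{median_ms:.3f} ms")
```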

Forward pass (consistent wins):

| Model | Hidden (N) | main | autotuned | Speedup |
| --- | --- | --- | --- | --- |
| deepseek_v2_lite | 2048 | 0.039 ms | 0.039 ms | 1.00x |
| qwen2.5_7b | 3584 | 0.066 ms | 0.062 ms | 1.05x |
| llama_2_7b | 4096 | 0.075 ms | 0.072 ms | 1.04x |
| llama_3_8b | 4096 | 0.075 ms | 0.072 ms | 1.04x |
| qwen2.5_14b | 5120 | 0.092 ms | 0.089 ms | 1.03x |
| deepseek_v3 | 7168 | 0.127 ms | 0.124 ms | 1.02x |
| qwen2.5_72b | 8192 | 0.144 ms | 0.140 ms | 1.03x |

Full (fwd + bwd):

| Model | main | autotuned | Speedup |
| --- | --- | --- | --- |
| deepseek_v2_lite | 0.137 ms | 0.136 ms | 1.01x |
| qwen2.5_7b | 0.219 ms | 0.214 ms | 1.02x |
| llama_2_7b | 0.248 ms | 0.244 ms | 1.01x |
| llama_3_8b | 0.249 ms | 0.246 ms | 1.01x |
| qwen2.5_14b | 0.304 ms | 0.298 ms | 1.02x |
| deepseek_v3 | 0.419 ms | 0.404 ms | 1.04x |
| qwen2.5_72b | 0.466 ms | 0.462 ms | 1.01x |

Zero regressions observed across all 42 (beta x mode x model) combinations; the worst case was -1% on 2 configs, within do_bench variance. Memory usage is unchanged.

Backward shows smaller improvements because it is dominated by the per-SM rows_per_program loop over M rather than by the column axis that BLOCK_N tiles (see the sketch below).
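A loudly hypothetical sketch of that structure (the grid layout, indexing, and dalpha integrand are stand-ins, not the real kernel), just to show where the time goes:

```python
import triton
import triton.language as tl

@triton.jit
def _bwd_row_loop_sketch(X, DY, DA, M, N, rows_per_program,
                         BLOCK_N: tl.constexpr):
    # Each program owns one (row-slice, column-tile) pair and walks
    # rows_per_program rows; the hot loop runs over M, so tuning the
    # column-axis BLOCK_N has less leverage than in the forward kernel.
    pid = tl.program_id(0)
    tiles_n = tl.cdiv(N, BLOCK_N)
    cols = (pid % tiles_n) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = cols < N
    row_start = (pid // tiles_n) * rows_per_program
    row_end = tl.minimum(row_start + rows_per_program, M)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for row in range(row_start, row_end):
        x = tl.load(X + row * N + cols, mask=mask, other=0.0)
        dy = tl.load(DY + row * N + cols, mask=mask, other=0.0)
        acc += x.to(tl.float32) * dy.to(tl.float32)  # stand-in integrand
    # One partial-sum slot per program; the host finishes with DA.sum().
    tl.store(DA + pid, tl.sum(acc, axis=0))
```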

Caveats

  • Autotune compile cost (~9 s per unique N) is paid once per process and is absorbed by triton.testing.do_bench's built-in warmup, so the timings above reflect post-warmup steady state.
  • Config space was chosen based on A10G (Ampere) benchmarks. Hopper/Blackwell users may benefit from additional configs; Triton's autotune will still pick the best available from the current set.

Testing Done

  • Hardware: NVIDIA A10G (sm_86, Ampere)
  • pytest test/transformers/test_dyt.py -xvs -> 24 passed (correctness + functional, both beta=True/beta=False, dtypes fp32 and bf16)
  • python benchmark/scripts/benchmark_dyt.py --sweep-mode model_config --overwrite
  • run make test to ensure correctness (ran pytest test/transformers/test_dyt.py — scoped to changed kernel)
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence (not applicable — no numerical changes; autotune is a performance-only change)

The author had drafted autotune decorators but left them commented out with
a static BLOCK_N heuristic in the Python wrappers instead. This enables
autotune with a pruned config space (18 configs per kernel: BLOCK_N in
{1024, 2048, 4096}, num_stages in {1, 2}, num_warps in {4, 8, 16}).

Dropped num_stages=4 and num_warps=32 from the original stubs since
elementwise kernels rarely benefit from deeper pipelining or wide warps.
The da reduction buffer stays at cdiv(N, 512) — safely over-provisions
for all BLOCK_N >= 512 in the autotune space.

DA is indexed by program_id(0), so the number of slots written per SM
scales with cdiv(N, BLOCK_N). Autotune runs every config back-to-back
without zeroing output buffers between trials, so slots written by a
prior BLOCK_N config leak into the final da.sum() of the winning
config. reset_to_zero=["DA"] forces autotune to zero DA before each
trial, matching the pattern used in fused_moe_kernels.py.
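A toy torch demonstration of that leak (numbers invented): on N=4096, a BLOCK_N=1024 trial fills four slots, and a following BLOCK_N=4096 trial overwrites only one, so without re-zeroing, the winner's da.sum() still includes three stale partials:

```python
import torch

DA = torch.empty(4)                         # over-provisioned partial-sum buffer
DA[:] = torch.tensor([1.0, 1.0, 1.0, 1.0])  # BLOCK_N=1024 trial: 4 partials
DA[0] = 4.0                                 # BLOCK_N=4096 trial: 1 partial
print(DA.sum())  # 7.0, not the correct 4.0 -> hence reset_to_zero=["DA"]
```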
@shivam2199
Contributor Author

@Tcc0403 @Mecoli1219 Can you check this once?

Collaborator

@Mecoli1219 left a comment


LGTM!
