perf(dyt): enable triton.autotune for DyT forward and backward kernels#1208
Open
shivam2199 wants to merge 2 commits into linkedin:main from
Conversation
The original author drafted autotune decorators but left them commented
out, with a static BLOCK_N heuristic in the Python wrappers as a
placeholder. This change enables
autotune with a pruned config space (18 configs per kernel: BLOCK_N in
{1024, 2048, 4096}, num_stages in {1, 2}, num_warps in {4, 8, 16}).
Dropped num_stages=4 and num_warps=32 from the original stubs since
elementwise kernels rarely benefit from deeper pipelining or wide warps.
The da reduction buffer stays at cdiv(N, 512), which safely over-provisions
for all BLOCK_N >= 512 in the autotune space: cdiv(N, BLOCK_N) <= cdiv(N, 512)
whenever BLOCK_N >= 512, so no config writes more slots than the buffer holds.
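A quick sanity check of that sizing claim (plain arithmetic using triton.cdiv, not code from this PR):

```python
import triton

N = 8192  # example hidden size
slots = triton.cdiv(N, 512)  # 16 da slots allocated
for block_n in (1024, 2048, 4096):
    # programs that write DA: 8, 4, 2 respectively -- all <= 16
    assert triton.cdiv(N, block_n) <= slots
```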
DA is indexed by program_id(0), so the number of slots written per SM scales with cdiv(N, BLOCK_N). Autotune runs every config back-to-back without zeroing output buffers between trials, so slots written by a prior BLOCK_N config leak into the final da.sum() of the winning config. reset_to_zero=["DA"] forces autotune to zero DA before each trial, matching the pattern used in fused_moe_kernels.py.
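For concreteness, a minimal sketch of the decorator this adds to the backward kernel, with the config values described above; the kernel signature here is illustrative, not the exact one in the repo's dyt ops module:

```python
import itertools

import triton
import triton.language as tl

# 3 BLOCK_N x 2 num_stages x 3 num_warps = 18 configs, as pruned above
_DYT_BWD_CONFIGS = [
    triton.Config({"BLOCK_N": bn}, num_stages=ns, num_warps=nw)
    for bn, ns, nw in itertools.product((1024, 2048, 4096), (1, 2), (4, 8, 16))
]

@triton.autotune(
    configs=_DYT_BWD_CONFIGS,
    key=["N"],             # re-tune when the column count changes
    reset_to_zero=["DA"],  # zero DA before every trial so slots written by a
                           # prior BLOCK_N config can't survive into the winner
)
@triton.jit
def _dyt_bwd_kernel(DY, DX, DA, X, Alpha, M, N, BLOCK_N: tl.constexpr):
    ...  # reduction body unchanged by this PR
```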
Contributor
Author
@Tcc0403 @Mecoli1219 Can you check this once?
Summary
Enables @triton.autotune on the DyT (Dynamic Tanh) forward and backward kernels. The autotune configs were drafted by the original author but left commented out, with a static BLOCK_N heuristic in the Python wrappers (liger_dyt_fwd / liger_dyt_bwd) as a placeholder.

Changes
- Enabled @triton.autotune on both _dyt_fwd_kernel and _dyt_bwd_kernel with a pruned config space (18 configs each):
  - BLOCK_N in {1024, 2048, 4096}
  - num_stages in {1, 2} (dropped num_stages=4 from the original stub; it rarely wins for memory-bound elementwise kernels)
  - num_warps in {4, 8, 16} (dropped num_warps=32 from the fwd stub; overkill for this kernel)
- Removed the static BLOCK_N heuristic (if N >= 4096:) and manual kwargs in both Python wrappers; autotune handles selection (see the wrapper sketch after this list).
- Added reset_to_zero=["DA"] to the backward autotune decorator. The DA reduction buffer is indexed by program_id(0), so different BLOCK_N configs write to different slot counts per SM. Autotune trials don't zero output buffers between runs, so stale slots from a prior config would leak into the final da.sum() of the winning config. Same pattern as used in fused_moe_kernels.py.
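To make the wrapper change concrete, here is a self-contained sketch of the "after" state; the kernel body and wrapper signature are illustrative stand-ins for the real liger_dyt_fwd, and the removed heuristic is shown only in comments:

```python
import itertools

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": bn}, num_stages=ns, num_warps=nw)
        for bn, ns, nw in itertools.product((1024, 2048, 4096), (1, 2), (4, 8, 16))
    ],
    key=["N"],
)
@triton.jit
def _dyt_fwd_kernel(X, Y, Alpha, N, BLOCK_N: tl.constexpr):
    # One program per row; walk the row in BLOCK_N-wide chunks.
    row = tl.program_id(0)
    alpha = tl.load(Alpha).to(tl.float32)
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        mask = cols < N
        x = tl.load(X + row * N + cols, mask=mask).to(tl.float32)
        # tanh via the sigmoid identity, tanh(z) = 2*sigmoid(2z) - 1,
        # to stay on core tl ops
        y = 2 * tl.sigmoid(2 * alpha * x) - 1
        tl.store(Y + row * N + cols, y.to(Y.dtype.element_ty), mask=mask)

def liger_dyt_fwd(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    y = torch.empty_like(x)
    # Before this PR the wrapper chose meta-parameters by hand, e.g.:
    #   kwargs = {"BLOCK_N": 4096, "num_warps": 16} if N >= 4096 else {...}
    # After: the launch passes none; @triton.autotune selects them per N.
    _dyt_fwd_kernel[(M,)](x, y, alpha, N)
    return y
```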
Performance (NVIDIA A10G, sm_86, bf16, BT=2048)

Measured via benchmark/scripts/benchmark_dyt.py --sweep-mode model_config across 7 model hidden sizes (2048 -> 8192) x both beta=True/beta=False x fwd/bwd/full. Median over do_bench trials.

Forward pass (consistent wins):
Full (fwd + bwd):
Zero regressions observed across 42 (beta x mode x model) combinations (max noise: -1% on 2 configs, within do_bench variance). Memory unchanged.

Backward shows smaller improvements because it's dominated by the per-SM rows_per_program loop over M rather than the column axis.

Caveats
- The one-time autotune sweep (keyed on N) is paid once per process and absorbed by triton.testing.do_bench's built-in warmup, so measured speeds above are post-warmup steady-state.
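As a sketch of how that warmup interacts with measurement (do_bench and its return_mode argument are real Triton testing APIs; the benchmarked call and the x/alpha inputs are assumed from context):

```python
from triton.testing import do_bench

# The first liger_dyt_fwd call for a given N runs the 18-config sweep.
# do_bench executes warmup iterations before it starts timing, so that
# one-time cost is excluded and the median reflects steady state.
ms = do_bench(lambda: liger_dyt_fwd(x, alpha), return_mode="median")
```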
Testing Done

- pytest test/transformers/test_dyt.py -xvs -> 24 passed (correctness + functional, both beta=True/beta=False, dtypes fp32 and bf16)
- python benchmark/scripts/benchmark_dyt.py --sweep-mode model_config --overwrite
- make test to ensure correctness (ran pytest test/transformers/test_dyt.py, scoped to the changed kernel)
- make checkstyle to ensure code style
- make test-convergence to ensure convergence (not applicable; no numerical changes, autotune is a performance-only change)