add fuse ar rms op#48

Open
crazy-JiangDongHua wants to merge 6 commits into Tencent:main from crazy-JiangDongHua:Feature/Fuse_AR_RMSNorm


crazy-JiangDongHua commented Jun 2, 2026

This PR fuses tensor-parallel AllReduce, residual add, and RMSNorm into one NVLink-native op, RMSNorm(AllReduce(x) + residual, weight), avoiding the extra kernel launches and HBM round-trips incurred when the three steps run separately. It is built on CUDA multicast (multimem) and P2P access, and supports bfloat16 on single-node multi-GPU setups. Two modes are provided, both built on a two-shot (reduce-scatter + all-gather) schedule:

- High-throughput mode (fuse_allreduce_rmsnorm_high_throughput): a single fused kernel that performs the reduction over NVSwitch multicast. Best for large token counts (prefill).
- Low-latency mode (fuse_allreduce_rmsnorm_low_latency): a Lamport P2P exchange split into two kernels that overlap via PDL (programmatic dependent launch). Best for small token counts (decode); requires a power-of-two world size.
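
For readers unfamiliar with the two-shot schedule, the following is a minimal CPU-only sketch of the op's semantics in plain Python, not the kernels in this PR. It assumes a token-wise partition across ranks, an epsilon of 1e-6, and that normalization happens between the two shots; none of these details are stated in the description, and the function names `rmsnorm`, `reference`, and `two_shot` are hypothetical.

```python
import math

EPS = 1e-6  # assumed epsilon; the PR description does not state one


def rmsnorm(row, weight, eps=EPS):
    """RMSNorm over one token's hidden vector."""
    mean_sq = sum(v * v for v in row) / len(row)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(row, weight)]


def reduce_token(xs, residual, t, hidden):
    """Sum token t across all ranks' partial results, then add the residual."""
    return [sum(x[t][h] for x in xs) + residual[t][h] for h in range(hidden)]


def reference(xs, residual, weight):
    """Unfused semantics: RMSNorm(AllReduce(x) + residual, weight)."""
    hidden = len(weight)
    return [rmsnorm(reduce_token(xs, residual, t, hidden), weight)
            for t in range(len(residual))]


def two_shot(xs, residual, weight):
    """Two-shot schedule simulated on one process.

    Shot 1 (reduce-scatter): each rank reduces only its own slice of tokens.
    Local step: the owning rank adds the residual and normalizes its slice.
    Shot 2 (all-gather): slices are concatenated so every rank sees all tokens.
    """
    world, tokens, hidden = len(xs), len(residual), len(weight)
    per_rank = (tokens + world - 1) // world  # ceil division; last ranks may be empty
    shards = []
    for rank in range(world):
        lo = rank * per_rank
        hi = min(lo + per_rank, tokens)
        shards.append([rmsnorm(reduce_token(xs, residual, t, hidden), weight)
                       for t in range(lo, hi)])
    return [row for shard in shards for row in shard]


# Small deterministic inputs: 4 ranks, 6 tokens, hidden size 8.
world, tokens, hidden = 4, 6, 8
xs = [[[0.01 * (r + 1) * (t - h) for h in range(hidden)]
       for t in range(tokens)] for r in range(world)]
residual = [[0.1 * (t + h) for h in range(hidden)] for t in range(tokens)]
weight = [1.0 + 0.05 * h for h in range(hidden)]

expected = reference(xs, residual, weight)
actual = two_shot(xs, residual, weight)
```

Because RMSNorm only needs the full hidden vector of a single token, a token-wise partition lets each rank normalize its slice without further communication, which is why the norm can sit between the reduce-scatter and the all-gather in this sketch.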

See benchmark/bench_allreduce_rmsnorm.py for an 8-GPU comparison against NCCL and FlashInfer.

crazy-JiangDongHua force-pushed the Feature/Fuse_AR_RMSNorm branch from 60c508d to d1da6b5 on June 3, 2026 11:49
Co-authored-by: lvjx04 <2108244896@qq.com>
crazy-JiangDongHua force-pushed the Feature/Fuse_AR_RMSNorm branch from d1da6b5 to 8a49bd1 on June 3, 2026 11:51

2 participants