We are the ByteDance Seed team.
You can get to know us better through the following channels👇
Original Triton README | README in Chinese
Triton-distributed is a distributed compiler for computation-communication overlapping, built on top of OpenAI Triton.
With Triton-distributed, programmers can develop efficient kernels comparable to highly optimized libraries such as Distributed-GEMM and FLUX. Triton-distributed currently targets NVIDIA and AMD GPUs, and it can also be ported to other hardware platforms. Feel free to contact us if you want to use Triton-distributed on your own hardware.
See build from source, or follow the installation steps below.
Prepare the PyTorch container:

```sh
docker run --name triton-dist --ipc=host --network=host --privileged --cap-add=SYS_ADMIN --shm-size=10g --gpus=all -itd nvcr.io/nvidia/pytorch:25.04-py3 /bin/bash
docker exec -it triton-dist /bin/bash
```
Then, download and set up NVSHMEM manually (we cannot do this for you due to NVSHMEM license requirements). See prepare NVSHMEM.
After that, install clang-19:

```sh
apt update
apt install clang-19 llvm-19 libclang-19-dev
```
Then, pip install Triton-distributed:

```sh
export NVSHMEM_SRC=/workspace/nvshmem   # path to the NVSHMEM source prepared above
export CC=clang-19
export CXX=clang++-19
pip install "git+https://github.com/ByteDance-Seed/Triton-distributed.git#subdirectory=python" --no-build-isolation --force-reinstall
```
Triton-distributed provides a set of easy-to-use primitives to support the development of distributed compute-communication overlapping kernels. The primitives are divided into low-level primitives and high-level primitives. Currently, we have released our low-level primitives, and we plan to release the high-level primitives in the future.
Using these primitives, users can program communication kernels easily. For example, a low-latency AllToAll (with better latency than DeepEP for inference) is shown below. On 32 H800 GPUs, this example takes 137 µs (128 tokens per rank, topk=8, hidden_size=7168, dtype=fp8), while DeepEP takes 182 µs (note: DeepEP doesn't use NVLink for inference).
```python
import triton
import triton.language as tl

# Import paths follow the Triton-distributed source tree; adjust if your
# version places these helpers elsewhere.
from triton.language.extra.cuda.language_extra import tid
from triton_dist.language.extra import libshmem_device


@triton.jit
def all_to_all_kernel(
    data_src,
    data_dst,
    splits_src,
    splits_dst,
    signal,
    splits_cumsum,
    scale_src,
    scale_dst,
    rank: int,
    call_count: int,
    WITH_SCALE: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    HIDDEN: tl.constexpr,
    MAX_M: tl.constexpr,
    EXPERTS_PER_RANK: tl.constexpr,
    NUM_TOT_EXPERTS: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr = 2,
    SCALE_ELEMENT_SIZE: tl.constexpr = 4,
):
    # One program (thread block) per destination rank.
    pid = tl.program_id(0)
    threadidx = tid(axis=0)

    # The contiguous expert range owned by destination rank `pid`.
    exp_st = pid * EXPERTS_PER_RANK
    exp_ed = exp_st + EXPERTS_PER_RANK

    # Row range of local tokens destined for rank `pid` (tokens are sorted
    # by expert, so the cumulative splits give contiguous row ranges).
    m_st = tl.load(splits_cumsum + exp_st)
    m_ed = tl.load(splits_cumsum + exp_ed)
    num_rows_cur_block = m_ed - m_st
    src_off = m_st
    dst_off = rank * MAX_M

    # Per-expert split sizes for rank `pid`, derived from the cumsum.
    split_src_ptr = splits_src + exp_st
    off0 = exp_st + tl.arange(0, EXPERTS_PER_RANK)
    off1 = exp_st + tl.arange(0, EXPERTS_PER_RANK) + 1
    cumsum_sts = tl.load(splits_cumsum + off0)
    cumsum_eds = tl.load(splits_cumsum + off1)
    tl.store(split_src_ptr + tl.arange(0, EXPERTS_PER_RANK), cumsum_eds - cumsum_sts)

    # Double buffering: alternate between two buffer slots across calls.
    act_pos = call_count % 2
    data_dst_ptr = data_dst + act_pos * WORLD_SIZE * MAX_M * HIDDEN + dst_off * HIDDEN
    split_dst_ptr = splits_dst + act_pos * NUM_TOT_EXPERTS + rank * EXPERTS_PER_RANK
    signal_ptr = signal + act_pos * WORLD_SIZE + rank

    # Non-blocking, block-cooperative puts: token data, then split metadata.
    libshmem_device.putmem_nbi_block(
        data_dst_ptr,
        data_src + src_off * HIDDEN,
        num_rows_cur_block * HIDDEN * ELEMENT_SIZE,
        pid,
    )
    libshmem_device.putmem_nbi_block(
        split_dst_ptr,
        split_src_ptr,
        EXPERTS_PER_RANK * 4,  # splits are int32, hence 4 bytes each
        pid,
    )
    if WITH_SCALE:
        # Fused put + signal: the remote signal is set once the scales land.
        scale_dst_ptr = scale_dst + act_pos * WORLD_SIZE * MAX_M + dst_off
        libshmem_device.putmem_signal_nbi_block(
            scale_dst_ptr,
            scale_src + src_off,
            num_rows_cur_block * SCALE_ELEMENT_SIZE,
            signal_ptr,
            call_count,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            pid,
        )

    libshmem_device.fence()
    if threadidx == 0:
        if not WITH_SCALE:
            # No fused signal above, so set the remote signal explicitly.
            libshmem_device.signal_op(
                signal_ptr,
                call_count,
                libshmem_device.NVSHMEM_SIGNAL_SET,
                pid,
            )
        # Wait until the tokens sent by rank `pid` have arrived locally.
        libshmem_device.signal_wait_until(
            signal + act_pos * WORLD_SIZE + pid,
            libshmem_device.NVSHMEM_CMP_EQ,
            call_count,
        )
```
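To make the kernel's indexing concrete, the following host-side model (plain PyTorch, illustrative only, not repository code) mimics the buffer layout: tokens are sorted by expert, `splits_cumsum` is the prefix sum over all experts with a leading zero, rank `r` writes into row slot `r * MAX_M` of each peer's receive buffer, and `call_count % 2` selects one of two double-buffered slots.

```python
import torch

# Host-side layout model for the AllToAll above (illustrative, not repository code).
WORLD_SIZE, EXPERTS_PER_RANK, MAX_M, HIDDEN = 4, 2, 16, 8
NUM_TOT_EXPERTS = WORLD_SIZE * EXPERTS_PER_RANK
rank, call_count = 1, 3

splits = torch.randint(0, 3, (NUM_TOT_EXPERTS,))  # local token count per expert
splits_cumsum = torch.cat([torch.zeros(1, dtype=torch.long), splits.cumsum(0)])

data_src = torch.randn(int(splits.sum()), HIDDEN)  # local tokens, sorted by expert
# Receive-buffer layout: [2 (double buffer), WORLD_SIZE senders, MAX_M rows, HIDDEN].
data_dst = torch.zeros(2, WORLD_SIZE, MAX_M, HIDDEN)

act_pos = call_count % 2
for pid in range(WORLD_SIZE):  # what program `pid` of the kernel would send
    m_st = int(splits_cumsum[pid * EXPERTS_PER_RANK])
    m_ed = int(splits_cumsum[(pid + 1) * EXPERTS_PER_RANK])
    # In the kernel this is a remote put into rank `pid`'s symmetric buffer;
    # here we model it as a local copy into "rank pid's" receive buffer,
    # landing at the sender's slot (row offset rank * MAX_M).
    data_dst[act_pos, rank, : m_ed - m_st] = data_src[m_st:m_ed]
```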
Users can also combine the communication part with the computation part to design overlapping kernels. We provide example implementations in python/triton_dist/kernels; a minimal sketch of the pattern follows below.
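As a minimal sketch of the overlapping pattern (hypothetical code, not one of the repository kernels): a consumer program can wait on the per-rank arrival signal set by the AllToAll above and start computing on one rank's tokens while transfers from other ranks are still in flight. Here every thread of the program polls the signal, so no intra-block barrier is needed, and the multiply-by-two store is a placeholder for real computation.

```python
import triton
import triton.language as tl
from triton_dist.language.extra import libshmem_device  # same import as above


@triton.jit
def consume_when_ready_kernel(
    data,      # received tokens: [WORLD_SIZE * MAX_M, HIDDEN], written remotely
    out,
    signal,    # per-source-rank arrival signals (one slot per sender)
    expected,  # signal value written by the producer (call_count)
    MAX_M: tl.constexpr,
    HIDDEN: tl.constexpr,
    BLOCK: tl.constexpr,  # power of two >= HIDDEN
):
    # One program per source rank: block until that rank's tokens have
    # arrived, then compute on them while other transfers are in flight.
    pid = tl.program_id(0)
    libshmem_device.signal_wait_until(
        signal + pid,
        libshmem_device.NVSHMEM_CMP_EQ,
        expected,
    )
    offs = tl.arange(0, BLOCK)
    mask = offs < HIDDEN
    for i in range(MAX_M):
        row = tl.load(data + (pid * MAX_M + i) * HIDDEN + offs, mask=mask)
        # Placeholder computation: scale each received row by two.
        tl.store(out + (pid * MAX_M + i) * HIDDEN + offs, row * 2.0, mask=mask)
```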
Triton-distributed can achieve performance comparable to or better than hand-tuned libraries.

(Figure: performance comparison against hand-tuned libraries; the batch size is 1, i.e., one query, for decoding.)
Functionalities
- Release low-level primitives
- Release high-level primitives
- Tutorials
- Pre-built binary
Kernels
- Release single-node GEMM TP overlapping kernels
- Release single-node MoE TP overlapping kernels
- Release single-node distributed Flash-Decoding kernels
- Release single-node MoE EP overlapping kernels
- Release cross-node GEMM TP overlapping kernels
- Release cross-node MoE TP overlapping kernels
- Release cross-node distributed Flash-Decoding kernels
- Release cross-node EP all-to-all kernels (similar to DeepEP)
- Provide tutorials for kernel implementation
Computation
- Nvidia SM90a support
- Nvidia SM80 support
- Nvidia SM89 support
- AMD CDNA3 support
Communication
- NVLink
- IB
- PCIe
- Performance report
The Triton-distributed project is under the MIT License. Part of our code is under the Apache-2.0 License:
- python/triton_dist/kernels/flash_decode.py
If you use Triton-distributed in a scientific publication, we encourage you to cite the related papers:
```bibtex
@misc{zheng2025tritondistributed,
      title={Triton-distributed: Programming Overlapping Kernels on Distributed AI Systems with the Triton Compiler},
      author={Size Zheng and Wenlei Bao and Qi Hou and Xuegui Zheng and Jin Fang and Chenhui Huang and Tianqi Li and Haojie Duanmu and Renze Chen and Ruifan Xu and Yifan Guo and Ningxin Zheng and Ziheng Jiang and Xinyi Di and Dongyang Wang and Jianxi Ye and Haibin Lin and Li-Wen Chang and Liqiang Lu and Yun Liang and Jidong Zhai and Xin Liu},
      year={2025},
      eprint={2504.19442},
      archivePrefix={arXiv},
      primaryClass={cs.DC},
      url={https://arxiv.org/abs/2504.19442},
}
```
```bibtex
@inproceedings{zheng2025tilelink,
      author = {Size Zheng and Jin Fang and Xuegui Zheng and Qi Hou and Wenlei Bao and Ningxin Zheng and Ziheng Jiang and Dongyang Wang and Jianxi Ye and Haibin Lin and Li-Wen Chang and Xin Liu},
      booktitle = {Proceedings of Machine Learning and Systems},
      title = {TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives},
      url = {https://arxiv.org/abs/2503.20313},
      year = {2025}
}
```
About ByteDance Seed Team
Founded in 2023, the ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and to make significant contributions to the advancement of science and society.
Please use issues or pull requests for discussion and contribution (see CONTRIBUTING.md).