
[Feature Request] Fine-grained communication API #65

@jinhongyii

Description


Describe the feature

While the coarse-grained NCCL API already handles collective communications such as allreduce and allgather well, the all-to-all dispatch/combine in MoE models poses new challenges for efficient communication. All-to-all dispatch/combine is sparse: the communication size differs between each pair of devices. Compared to communication that is organized carefully with more fine-grained primitives, the standard NCCL implementation can be 10x slower (performance numbers from https://www.perplexity.ai/hub/blog/efficient-and-portable-mixture-of-experts-communication).
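To illustrate the sparsity, the hypothetical helper below counts how many (token, expert) copies each source rank sends to each destination rank, given a routing table and an expert-to-rank placement. The function name, the layout, and the routing values are made up for this example. It counts one copy per (token, expert) pair, as the pseudocode below does.

```python
def send_counts(routing, expert2rank, num_ranks):
    """Return counts[src][dst]: number of (token, expert) copies rank src sends to rank dst.

    routing[src] is a list of per-token expert lists on rank src;
    expert2rank[exp] is the rank hosting expert exp (hypothetical layout).
    """
    counts = [[0] * num_ranks for _ in range(num_ranks)]
    for src, token_experts in enumerate(routing):
        for experts in token_experts:
            for exp in experts:
                counts[src][expert2rank[exp]] += 1
    return counts


# 2 ranks, 4 experts: experts 0-1 live on rank 0, experts 2-3 on rank 1
routing = [
    [[0, 2], [2, 3]],  # rank 0: token 0 -> experts 0,2; token 1 -> experts 2,3
    [[1, 2]],          # rank 1: token 0 -> experts 1,2
]
print(send_counts(routing, [0, 0, 1, 1], 2))  # [[1, 3], [1, 1]]
```

Each device pair exchanges a different amount of data, and the amounts are only known at run time, after routing.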

Here we show how https://github.com/ppl-ai/pplx-kernels writes its all-to-all kernel. Below is pseudocode of the dispatch kernel:

local_rank = get_rank()
# in the original implementation this is done on a separate warp
# aggregate routing information
for exp in experts:
    num_tokens = num_tokens_to_expert[exp] # count tokens routed to exp
    exp_rank = expert2rank[exp] # which rank hosts the expert
    signal(total_count[local_rank, exp], exp_rank, exp, num_tokens) # tell the remote rank that `num_tokens` tokens must be received for exp; the count is stored in `total_count[local_rank, exp]`
# start send
for tok in local_tokens:
    for exp in routed_experts[tok]: # in the original implementation this work is split across multiple warps
        exp_rank = expert2rank[exp] # which rank hosts the expert
        putmem_signal_nbi(data_recv, exp_rank, exp, tok, recv_count[local_rank, exp], "add", 1) # non-blocking remote write of tok to exp on exp_rank, atomically adding 1 to the recv count
# start recv
for rank in ranks:
    for exp in local_experts:
        wait_until(total_count[rank, exp]) # wait for total_count[rank, exp] to be signaled
        wait_until(recv_count[rank, exp] == total_count[rank, exp]) # wait until all tokens for exp have been received from `rank`

# pack the tokens of the same expert together
...
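To make the control flow above concrete, here is a toy, single-process Python simulation under explicit assumptions: ranks are threads, `SymmetricHeap` is a hypothetical stand-in for NVSHMEM symmetric memory (not a real API), and its `signal`, `putmem_signal`, and `wait_until` methods only mimic the semantics used in the pseudocode. One deliberate deviation: "has been signaled" is modelled as the key being present, so an expert that receives zero tokens from a rank does not block forever. It sketches the one-sided protocol, not its performance.

```python
import threading
from collections import defaultdict


class SymmetricHeap:
    """Hypothetical stand-in for NVSHMEM symmetric memory: any rank may write any other rank's buffers."""

    def __init__(self, num_ranks):
        self.cond = threading.Condition()
        self.total_count = [dict() for _ in range(num_ranks)]  # [dst][(src, exp)] -> announced count
        self.recv_count = [defaultdict(int) for _ in range(num_ranks)]  # [dst][(src, exp)] -> delivered count
        self.data_recv = [defaultdict(list) for _ in range(num_ranks)]  # [dst][(src, exp)] -> payloads

    def signal(self, dst, key, value):
        # remote "set" of a signal word on rank dst
        with self.cond:
            self.total_count[dst][key] = value
            self.cond.notify_all()

    def putmem_signal(self, dst, key, token):
        # remote write plus "add 1" to the counter; the lock makes the pair atomic here
        with self.cond:
            self.data_recv[dst][key].append(token)
            self.recv_count[dst][key] += 1
            self.cond.notify_all()

    def wait_until(self, predicate):
        # block the calling rank until the predicate over its memory holds
        with self.cond:
            self.cond.wait_for(predicate)


def dispatch(heap, local_rank, local_tokens, routed_experts, expert2rank, num_ranks, results):
    # 1. aggregate routing info and announce per-expert counts (including 0)
    counts = defaultdict(int)
    for tok in local_tokens:
        for exp in routed_experts[tok]:
            counts[exp] += 1
    for exp, exp_rank in expert2rank.items():
        heap.signal(exp_rank, (local_rank, exp), counts[exp])
    # 2. send: one-sided writes, with no matching receive call on the target rank
    for tok in local_tokens:
        for exp in routed_experts[tok]:
            heap.putmem_signal(expert2rank[exp], (local_rank, exp), tok)
    # 3. receive: only wait on counters in local memory
    local_experts = [e for e, r in expert2rank.items() if r == local_rank]
    for src in range(num_ranks):
        for exp in local_experts:
            key = (src, exp)
            heap.wait_until(lambda: key in heap.total_count[local_rank])
            heap.wait_until(
                lambda: heap.recv_count[local_rank][key] == heap.total_count[local_rank][key]
            )
    # 4. pack tokens of the same expert together, ordered by source rank
    results[local_rank] = {
        exp: [t for src in range(num_ranks) for t in heap.data_recv[local_rank][(src, exp)]]
        for exp in local_experts
    }


def run_dispatch(tokens_per_rank, routed_experts, expert2rank):
    num_ranks = len(tokens_per_rank)
    heap = SymmetricHeap(num_ranks)
    results = [None] * num_ranks
    threads = [
        threading.Thread(
            target=dispatch,
            args=(heap, r, tokens_per_rank[r], routed_experts, expert2rank, num_ranks, results),
        )
        for r in range(num_ranks)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


expert2rank = {0: 0, 1: 0, 2: 1, 3: 1}  # experts 0-1 on rank 0, experts 2-3 on rank 1
routed_experts = {"a0": [0, 2], "a1": [2, 3], "b0": [1, 2]}
print(run_dispatch([["a0", "a1"], ["b0"]], routed_experts, expert2rank))
# [{0: ['a0'], 1: ['b0']}, {2: ['a0', 'a1', 'b0'], 3: ['a1']}]
```

Each rank finishes sending before it starts waiting, and it only ever waits on its own counters. This is the decoupling of send and receive logic that one-sided communication allows.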

The pseudocode uses three APIs: putmem_signal_nbi, signal, and wait_until (the actual function names are nvshmem_TYPENAME_put_signal_nbi, nvshmemx_signal_op, and nvshmem_TYPENAME_wait_until). These APIs are provided by NVSHMEM. The full specification can be found here: https://docs.nvidia.com/nvshmem/api/api.html

The main benefits of the NVSHMEM APIs are:

  1. Asynchronous communication (can overlap with other computation)
  2. One-sided communication (no need for matched send/recv pairs; the receive logic can run at a different time from the send logic)
  3. Fine-grained primitives that give the user more control

If NKI could expose fine-grained APIs similar to NVSHMEM, users would have more room to optimize communication.

Use Case

More efficient all-to-all dispatch/combine communication. State-of-the-art GPU communication libraries for this workload all use NVSHMEM. See https://github.com/bytedance/flux (fine-grained all-to-all overlapped with grouped GEMM), and https://github.com/deepseek-ai/DeepEP and https://github.com/ppl-ai/pplx-kernels (standalone all-to-all kernels).

Proposed Solution

No response

Other Information

No response

Acknowledgements

  • I may be able to implement this feature request
  • This feature might incur a breaking change

Metadata


Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
