Triton + Symmetric Memory fused all2all dispatch + token group padding kernel #4050

@danielvegamyhre

Description

We landed some decent CUDA kernels that do per-group padding the "naive" way: a standalone kernel that incurs an extra copy of the large input activations.
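To make the cost of the naive approach concrete, here is a minimal NumPy sketch of what a standalone padding step does: after dispatch, token groups arrive with ragged sizes, and a separate pass copies them into a buffer where each group starts at an offset rounded up to a multiple of 32. The function name, alignment constant, and shapes here are illustrative, not the actual kernel's API.

```python
import numpy as np

ALIGN = 32  # per-group alignment assumed for illustration

def pad_token_groups(tokens: np.ndarray, group_sizes: list[int]) -> np.ndarray:
    """Copy each expert's token group to a base offset aligned to ALIGN rows.

    This is the "extra copy" of the full activation tensor that the fused
    approach avoids.
    """
    dim = tokens.shape[1]
    padded_sizes = [(s + ALIGN - 1) // ALIGN * ALIGN for s in group_sizes]
    out = np.zeros((sum(padded_sizes), dim), dtype=tokens.dtype)
    src = dst = 0
    for size, padded in zip(group_sizes, padded_sizes):
        out[dst:dst + size] = tokens[src:src + size]  # the extra copy
        src += size
        dst += padded
    return out

# Example: two expert groups of 5 and 40 tokens, hidden dim 8.
tokens = np.arange(45 * 8, dtype=np.float32).reshape(45, 8)
padded = pad_token_groups(tokens, [5, 40])
print(padded.shape)  # (32 + 64, 8) = (96, 8)
```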

Fusing the padding into the all2all is needed to avoid the extra copy incurred by the standalone padding kernel described above, which hurt our speedup. The benefit of this approach is that, with the all-to-all dispatch, the receiver ranks are already going to be allocating a buffer for the incoming tokens. If we write those tokens to locations aligned to multiples of 32, we avoid the need for this expensive extra copy.
While we're doing this, we can also write incoming tokens grouped by local expert, instead of grouped by remote/source rank, in order to avoid the token shuffle kernel step (i.e., a 2D on-device all-to-all with no d2h sync).
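The fused layout can be sketched as an offset computation: the receiver knows how many tokens each local expert will get, so it can precompute a 32-aligned base offset per local expert group and have the all-to-all write each incoming token directly to its final location, with no post-hoc copy or shuffle. This is a hand-rolled illustration of the offset math under that assumption, not the kernel's actual code.

```python
ALIGN = 32  # per-group alignment assumed for illustration

def aligned_group_offsets(tokens_per_expert: list[int]) -> list[int]:
    """Base write offset for each local expert group, each aligned to ALIGN.

    Incoming token i of expert e is written at offsets[e] + i, so tokens land
    already grouped by local expert and already padded.
    """
    offsets, cursor = [], 0
    for n in tokens_per_expert:
        offsets.append(cursor)
        cursor += (n + ALIGN - 1) // ALIGN * ALIGN  # round up to ALIGN
    return offsets

# Example: three local experts receiving 10, 33, and 7 tokens.
print(aligned_group_offsets([10, 33, 7]))  # [0, 32, 96]
```

Because the offsets depend only on per-expert counts, the group sizes stay on device and no device-to-host sync is required to launch downstream grouped GEMMs.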
