You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[SymmMem][a2av] Use more CTAs for intra-node case (pytorch#153509)
Previously, we launch the a2av kernel with at most 8 blocks for intra-node cases, which turns out to saturate only 57 GB/s bandwidth.
This PR adds more blocks for intra-node, up to 8 per peer, pumping up data parallelism. The kernel now achieves 350 GB/s SOL for Hopper. See figure.
It also uses a simple tuning based on input size to avoid jumping to 8 CTAs directly (i.e. 1, 2, 4, then 8)
For inter-node, we cap at 8 blocks, since 57 GB/s seems bigger than regular NIC bandwidths (400 Gb/s).

Pull Request resolved: pytorch#153509
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#153483
0 commit comments