-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathalltoall_allpairs.py
More file actions
executable file
·28 lines (21 loc) · 1.02 KB
/
Copy pathalltoall_allpairs.py
File metadata and controls
executable file
·28 lines (21 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllToAll
# One-step AllToAll program:
# each GPU sends a chunk to, and receives a chunk from, every GPU (including itself)
def alltoall(num_ranks, instances, protocol):
    """Generate the one-step, all-pairs AllToAll MSCCL program.

    For every ordered pair of ranks (src, dst), input chunk ``dst`` on rank
    ``src`` is copied directly to output slot ``src`` on rank ``dst``, over a
    fully connected topology of ``num_ranks`` GPUs.

    Args:
        num_ranks: Number of GPUs (ranks) taking part in the collective.
        instances: Passed through unchanged as ``MSCCLProgram(instances=...)``.
        protocol: NCCL protocol name, passed through unchanged as
            ``MSCCLProgram(protocol=...)``.
    """
    # NOTE(review): the `1` given to AllToAll looks like a chunks-per-rank-pair
    # factor; confirm against msccl.language.collectives.AllToAll.
    program = MSCCLProgram(
        "alltoall_allpairs",
        fully_connected(num_ranks),
        AllToAll(num_ranks, 1, inplace=False),
        instances=instances,
        protocol=protocol,
    )
    with program:
        # Source-major order (src outer, dst inner), matching the nested loops
        # the program was originally written with.
        every_pair = ((src, dst) for src in range(num_ranks) for dst in range(num_ranks))
        for src, dst in every_pair:
            # src == dst is included on purpose: a rank keeps its own chunk too.
            chunk(src, Buffer.input, dst).copy(dst, Buffer.output, src)
        # Presumably XML() emits the finished program and Check() validates it
        # against the collective; confirm in msccl.language.
        XML()
        Check()
# Command-line entry point: parse the arguments and generate the program.
parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help='number of gpus')
parser.add_argument('instances', type=int, help='number of instances')
parser.add_argument('--protocol', type=str, default='Simple', choices=['Simple', 'LL', 'LL128'], help='NCCL protocol. Default: Simple')
args = parser.parse_args()
# argparse's type=int happily accepts zero and negative values. Reject them here
# with a normal usage error (parser.error exits with status 2) instead of letting
# a nonsensical rank/instance count reach alltoall().
if args.num_gpus < 1:
    parser.error('num_gpus must be a positive integer')
if args.instances < 1:
    parser.error('instances must be a positive integer')
alltoall(args.num_gpus, args.instances, args.protocol)