-
Notifications
You must be signed in to change notification settings - Fork 100
Expand file tree
/
Copy pathmulti_node_transfer.py
More file actions
85 lines (71 loc) · 3.26 KB
/
multi_node_transfer.py
File metadata and controls
85 lines (71 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Multi-Node Transfer Test
This file tests the SIGNAL, WAIT, PUT, FLUSH, PUT_WITH_SIGNAL and
PUT_WITH_SIGNAL_AND_FLUSH operations on PortChannels in a multi-node
environment. It implements a 2-GPU allgather using the Simple protocol,
exercising the different port-channel synchronization primitives.
"""
import argparse
from mscclpp.language.channel import *
from mscclpp.language.rank import *
from mscclpp.language.general import *
from mscclpp.language.program import *
from mscclpp.language.collectives import *
def multi_node_transfer(name, num_threads_per_block, min_message_size, max_message_size):
    """Build the 2-GPU port-channel transfer test program and print it as JSON.

    The program is declared inside a ``CollectiveProgram`` context for an
    ``AllGather`` collective over two ranks with two chunks per loop, and the
    result of ``JSON()`` is printed to stdout.

    Args:
        name: Name passed to ``CollectiveProgram``.
        num_threads_per_block: Forwarded to ``CollectiveProgram``.
        min_message_size: Forwarded to ``CollectiveProgram``.
        max_message_size: Forwarded to ``CollectiveProgram``.
    """
    num_ranks = 2
    chunks_per_loop = 2
    program_options = dict(
        protocol="Simple",
        num_threads_per_block=num_threads_per_block,
        use_double_scratch_buffer=False,
        min_message_size=min_message_size,
        max_message_size=max_message_size,
    )

    with CollectiveProgram(
        name,
        AllGather(num_ranks, chunks_per_loop, True),
        num_ranks,
        **program_options,
    ):
        # Ranks, two port channels per direction, and the output buffers of
        # both ranks. Creation order is kept as-is because the DSL objects are
        # registered in the order they are constructed.
        rank0 = Rank(0)
        rank1 = Rank(1)
        ch1_rank0 = PortChannel(1, 0)
        ch1_rank1 = PortChannel(0, 1)
        ch2_rank0 = PortChannel(1, 0)
        ch2_rank1 = PortChannel(0, 1)
        out0 = rank0.get_output_buffer()
        out1 = rank1.get_output_buffer()

        # Initial handshake: on each channel pair both sides SIGNAL and then
        # both sides WAIT before any data is moved. Channel pair 1 runs on
        # thread block 0, channel pair 2 on thread block 1.
        channel_pairs = ((ch1_rank0, ch1_rank1), (ch2_rank0, ch2_rank1))
        for tb, (rank0_side, rank1_side) in enumerate(channel_pairs):
            rank0_side.signal(tb=tb)
            rank1_side.signal(tb=tb)
            rank0_side.wait(tb=tb)
            rank1_side.wait(tb=tb)

        # Rank 0 -> rank 1 on channel 1: plain PUT, then explicit SIGNAL and
        # FLUSH as three separate operations.
        ch1_rank0.put(out1[0:1], out0[0:1], tb=0)
        ch1_rank0.signal(tb=0)
        ch1_rank0.flush(tb=0)

        # Rank 0 -> rank 1 on channel 2: PUT_WITH_SIGNAL, then a separate
        # FLUSH.
        ch2_rank0.put_with_signal(out1[1:2], out0[1:2], tb=1)
        ch2_rank0.flush(tb=1)

        # Rank 1 -> rank 0 on channel 1: a single fused
        # PUT_WITH_SIGNAL_AND_FLUSH covering chunks 2..3.
        ch1_rank1.put_with_signal_and_flush(out0[2:4], out1[2:4], tb=0)

        # Closing WAITs, one per signalled transfer issued above.
        ch1_rank0.wait(tb=0)
        ch1_rank1.wait(tb=0)
        ch2_rank1.wait(tb=1)

        # Emit the program while the CollectiveProgram context is still open.
        print(JSON())
# Command-line entry point. The guard keeps importing this module free of side
# effects: arguments are parsed and the program is generated only when the
# file is executed as a script.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, help="name of the program")
    parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
    parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
    parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")
    args = parser.parse_args()

    multi_node_transfer(args.name, args.num_threads_per_block, args.min_message_size, args.max_message_size)