 import ray
 import torch
 
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
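+    # Relay the dict through the pipeline group: every rank except the first
+    # receives it, every rank except the last sends it on, and receiving
+    # ranks check that the contents arrived unchanged.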
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
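+    # Same relay pattern with a bare tensor: non-first ranks receive,
+    # non-last ranks send, and receivers verify the values match.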
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        assert torch.allclose(test_tensor, recv_tensor)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     broadcast_tensor_dict_test_worker
 ])
 def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
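+# Launches pp_size single-GPU workers (tensor parallel size 1) per test target.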
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)