Commit 5d4d905

[Distributed] Add send and recv helpers (vllm-project#5719)

1 parent 6c916ac commit 5d4d905

6 files changed: +278 -24 lines changed

tests/distributed/test_comm_ops.py (+74 -4)

```diff
@@ -8,12 +8,11 @@
 import ray
 import torch
 
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        assert torch.allclose(test_tensor, recv_tensor)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     broadcast_tensor_dict_test_worker
 ])
 def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)
```
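The two new workers exercise the point-to-point helpers this commit adds to the pipeline-parallel group: `send`/`recv` for a single tensor and `send_tensor_dict`/`recv_tensor_dict` for a dict mixing GPU tensors, CPU tensors, and plain Python objects. As a rough usage sketch (not code from this commit), a pipeline stage could pass a fixed-size activation tensor downstream the same way the test does; the helper names and the `recv(size, dtype)` calling convention are taken from the test above, everything else is illustrative.

```python
# Illustrative sketch only (not part of this commit): move a fixed-size
# 1-D tensor from each pipeline stage to the next, mirroring the pattern
# in send_recv_test_worker above.
import torch

from vllm.distributed import get_pp_group

HIDDEN_SIZE = 64  # sender and receiver must agree on size and dtype


def forward_one_step() -> torch.Tensor:
    pp_group = get_pp_group()

    if pp_group.is_first_rank:
        # The first stage creates the tensor locally.
        hidden = torch.arange(HIDDEN_SIZE, dtype=torch.float32, device="cuda")
    else:
        # Later stages receive it from the previous pipeline rank;
        # recv(size, dtype) allocates and returns the received tensor.
        hidden = pp_group.recv(HIDDEN_SIZE, dtype=torch.float32)

    if not pp_group.is_last_rank:
        # Forward the tensor to the next pipeline rank.
        pp_group.send(hidden)

    return hidden
```

The same first-rank/last-rank guards apply to `send_tensor_dict`/`recv_tensor_dict`, which is how `send_recv_tensor_dict_test_worker` ships metadata alongside tensors.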

tests/distributed/test_custom_all_reduce.py (+2 -3)

```diff
@@ -12,8 +12,7 @@
                               get_tp_group, graph_capture)
 
 from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)
 
 random.seed(42)
 test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
```

tests/distributed/test_pynccl.py (+12 -4)

```diff
@@ -168,9 +168,13 @@ def send_recv_worker_fn():
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
         if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     assert result == 1
 
@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
                         device=device)
     with pynccl_comm.change_state(enable=True):
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
         assert result == 1
```
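Both updated workers now name their peers explicitly with ring arithmetic: the next rank is `(rank + 1) % world_size` and the previous rank is `(rank - 1) % world_size`. A minimal standalone illustration of that arithmetic (not part of the commit):

```python
# Python's % always yields a non-negative result, so rank 0 wraps around to
# the last rank when looking up its predecessor in the ring.
def ring_neighbors(rank: int, world_size: int):
    prev_rank = (rank - 1) % world_size
    next_rank = (rank + 1) % world_size
    return prev_rank, next_rank


assert ring_neighbors(0, 4) == (3, 1)  # first rank receives from the last
assert ring_neighbors(3, 4) == (2, 0)  # last rank sends on to the first
```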

tests/utils.py (+1 -1)

```diff
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
     ensure_model_parallel_initialized(tp_size, pp_size)
 
 
-def multi_process_tensor_parallel(
+def multi_process_parallel(
     tp_size: int,
     pp_size: int,
     test_target,
```

vllm/distributed/device_communicators/pynccl.py (+2 -12)

```diff
@@ -121,36 +121,26 @@ def all_reduce(self,
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))
 
-    def send(self,
-             tensor: torch.Tensor,
-             dst: Optional[int] = None,
-             stream=None):
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if dst is None:
-            dst = (self.rank + 1) % self.world_size
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
 
-    def recv(self,
-             tensor: torch.Tensor,
-             src: Optional[int] = None,
-             stream=None):
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if src is None:
-            src = (self.rank - 1) % self.world_size
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
```
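With the default-peer fallback removed, `send` and `recv` now require the caller to pass `dst`/`src` explicitly, so choosing the peer becomes the responsibility of higher-level code such as the group helpers exercised in the tests. The snippet below is a hypothetical wrapper, not the actual vLLM group-coordinator code; it only assumes the communicator exposes `rank`, `world_size`, and the `send(tensor, dst)`/`recv(tensor, src)` signatures shown in this diff.

```python
# Hypothetical wrapper (for illustration only): computes the ring peers once
# and feeds them to the explicit-peer send/recv methods from this commit.
import torch


class PipelineLink:

    def __init__(self, comm):
        # `comm` is assumed to behave like PyNcclCommunicator after this
        # change: it has .rank, .world_size, .send(tensor, dst=...),
        # and .recv(tensor, src=...).
        self.comm = comm
        self.next_rank = (comm.rank + 1) % comm.world_size
        self.prev_rank = (comm.rank - 1) % comm.world_size

    def send_to_next(self, tensor: torch.Tensor) -> None:
        # Point-to-point send to the next pipeline stage.
        self.comm.send(tensor, dst=self.next_rank)

    def recv_from_prev(self, tensor: torch.Tensor) -> None:
        # Receive into a pre-allocated tensor from the previous stage.
        self.comm.recv(tensor, src=self.prev_rank)
```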
