Skip to content

Commit c4e87ff

Browse files
committed
Fix weight-broadcast compatibility with vLLM >= 0.14 (pass an explicit CUDA stream to pynccl broadcast)
1 parent 76e7734 commit c4e87ff

File tree

2 files changed: +20 −8 lines

swift/pipelines/infer/rollout.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non
143143
weight = torch.empty(shape, dtype=dtype, device=self.communicator.device)
144144

145145
# Use NCCL to broadcast the updated weights from the client (src) to all workers.
146-
self.communicator.broadcast(weight, src=self.client_rank)
146+
self.communicator.broadcast(weight, src=self.client_rank, stream=torch.cuda.current_stream())
147147
self.communicator.group.barrier()
148148

149149
# Patch MoE weight_loader if needed
@@ -162,7 +162,7 @@ def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, me
162162
flatten_tensor_length = metadatas[-1].end_idx
163163
dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1])
164164
flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.communicator.device)
165-
self.communicator.broadcast(flatten_tensor, src=self.client_rank)
165+
self.communicator.broadcast(flatten_tensor, src=self.client_rank, stream=torch.cuda.current_stream())
166166
self.communicator.group.barrier()
167167
flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor)
168168
named_params = flattened_tensor_bucket.reconstruct_tensors()
@@ -194,7 +194,7 @@ def update_adapter_param(self, lora_int_id: int, peft_config: Dict, lora_tensors
194194
dtype = getattr(torch, metadata['dtype'].split('.')[-1])
195195
shape = tuple(metadata['shape'])
196196
tensor = torch.empty(shape, dtype=dtype, device=self.communicator.device)
197-
self.communicator.broadcast(tensor, src=self.client_rank)
197+
self.communicator.broadcast(tensor, src=self.client_rank, stream=torch.cuda.current_stream())
198198
named_params[name] = tensor
199199

200200
self.communicator.group.barrier()
@@ -222,7 +222,7 @@ def update_flattened_params(self, metadatas: list[Dict]) -> None:
222222
dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1])
223223
flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.communicator.device)
224224

225-
self.communicator.broadcast(flatten_tensor, src=self.client_rank)
225+
self.communicator.broadcast(flatten_tensor, src=self.client_rank, stream=torch.cuda.current_stream())
226226
self.communicator.group.barrier()
227227

228228
flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor)

swift/rlhf_trainers/vllm_client.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ def _update_single_server(i):
230230
if response.status_code != 200:
231231
raise Exception(f'Server {i} update failed: {response.text}')
232232

233-
self.pynccl_comms[i].broadcast(weights, src=self.pynccl_comms[i].rank)
233+
torch.cuda.synchronize()
234+
self.pynccl_comms[i].broadcast(
235+
weights, src=self.pynccl_comms[i].rank, stream=torch.cuda.current_stream())
236+
torch.cuda.synchronize()
234237
self.pynccl_comms[i].group.barrier()
235238
except Exception as e:
236239
errors[i] = e
@@ -275,7 +278,10 @@ def _update_single_server(i):
275278
if response.status_code != 200:
276279
raise Exception(f'Server {i} update adapter failed: {response.text}')
277280

278-
self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank)
281+
torch.cuda.synchronize()
282+
self.pynccl_comms[i].broadcast(
283+
flattened_tensor, src=self.pynccl_comms[i].rank, stream=torch.cuda.current_stream())
284+
torch.cuda.synchronize()
279285
self.pynccl_comms[i].group.barrier()
280286
except Exception as e:
281287
errors[i] = e
@@ -333,8 +339,11 @@ def _update_single_server(i):
333339
raise Exception(f'Server {i} update adapter failed: {response.text}')
334340

335341
# Broadcast each tensor individually
342+
torch.cuda.synchronize()
336343
for name, param in lora_params.items():
337-
self.pynccl_comms[i].broadcast(param, src=self.pynccl_comms[i].rank)
344+
self.pynccl_comms[i].broadcast(
345+
param, src=self.pynccl_comms[i].rank, stream=torch.cuda.current_stream())
346+
torch.cuda.synchronize()
338347
self.pynccl_comms[i].group.barrier()
339348
except Exception as e:
340349
errors[i] = e
@@ -372,7 +381,10 @@ def _update_single_server(i):
372381
if response.status_code != 200:
373382
raise Exception(f'Server {i} update flattened params failed: {response.text}')
374383

375-
self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank)
384+
torch.cuda.synchronize()
385+
self.pynccl_comms[i].broadcast(
386+
flattened_tensor, src=self.pynccl_comms[i].rank, stream=torch.cuda.current_stream())
387+
torch.cuda.synchronize()
376388
self.pynccl_comms[i].group.barrier()
377389
except Exception as e:
378390
errors[i] = e

0 commit comments

Comments
 (0)