@@ -143,7 +143,7 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non
143143 weight = torch .empty (shape , dtype = dtype , device = self .communicator .device )
144144
145145 # Use NCCL to broadcast the updated weights from the client (src) to all workers.
146- self .communicator .broadcast (weight , src = self .client_rank )
146+ self .communicator .broadcast (weight , src = self .client_rank , stream = torch . cuda . current_stream () )
147147 self .communicator .group .barrier ()
148148
149149 # Patch MoE weight_loader if needed
@@ -162,7 +162,7 @@ def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, me
162162 flatten_tensor_length = metadatas [- 1 ].end_idx
163163 dtype = getattr (torch , metadatas [- 1 ].dtype .split ('.' )[- 1 ])
164164 flatten_tensor = torch .empty (flatten_tensor_length , dtype = dtype , device = self .communicator .device )
165- self .communicator .broadcast (flatten_tensor , src = self .client_rank )
165+ self .communicator .broadcast (flatten_tensor , src = self .client_rank , stream = torch . cuda . current_stream () )
166166 self .communicator .group .barrier ()
167167 flattened_tensor_bucket = FlattenedTensorBucket (metadata = metadatas , flattened_tensor = flatten_tensor )
168168 named_params = flattened_tensor_bucket .reconstruct_tensors ()
@@ -194,7 +194,7 @@ def update_adapter_param(self, lora_int_id: int, peft_config: Dict, lora_tensors
194194 dtype = getattr (torch , metadata ['dtype' ].split ('.' )[- 1 ])
195195 shape = tuple (metadata ['shape' ])
196196 tensor = torch .empty (shape , dtype = dtype , device = self .communicator .device )
197- self .communicator .broadcast (tensor , src = self .client_rank )
197+ self .communicator .broadcast (tensor , src = self .client_rank , stream = torch . cuda . current_stream () )
198198 named_params [name ] = tensor
199199
200200 self .communicator .group .barrier ()
@@ -222,7 +222,7 @@ def update_flattened_params(self, metadatas: list[Dict]) -> None:
222222 dtype = getattr (torch , metadatas [- 1 ].dtype .split ('.' )[- 1 ])
223223 flatten_tensor = torch .empty (flatten_tensor_length , dtype = dtype , device = self .communicator .device )
224224
225- self .communicator .broadcast (flatten_tensor , src = self .client_rank )
225+ self .communicator .broadcast (flatten_tensor , src = self .client_rank , stream = torch . cuda . current_stream () )
226226 self .communicator .group .barrier ()
227227
228228 flattened_tensor_bucket = FlattenedTensorBucket (metadata = metadatas , flattened_tensor = flatten_tensor )
0 commit comments