@@ -183,7 +183,7 @@ class BroadcastOperation:
183183
184184 Args:
185185 rank (int): The rank of the current process.
186- group_name (dist.ProcessGroup): The NCCL process group.
186+ ranks_group (int): The value identifying the process group.
187187 bucket (torch.Tensor): The tensor to broadcast.
188188 metadata (list[ParameterMeta]): The metadata of the tensor.
189189 """
@@ -224,7 +224,6 @@ class KIMICheckpointEngine(CheckpointEngine):
224224 Args:
225225 bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
226226 two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size.
227- group_name (str): The name of the NCCL process group. Defaults to "default".
228227 rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False.
229228 is_master (bool): Whether the current process is the master process. Defaults to False.
230229 rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16.
@@ -273,7 +272,8 @@ def init_process_group(self, rank: int, world_size: int, master_metadata: Master
273272 world_size (int): The total number of processes.
274273 """
275274 self.rank = rank
276- # unregister_memory in transfer engine is not supported on NPU, so we have to initialize ParameterServer each time
275+ # unregister_memory in transfer engine is not supported on NPU,
276+ # so we have to initialize ParameterServer each time
277277 if get_device_name() == "npu" or not self.initialized:
278278 self.parameter_server = ParameterServer(rank=rank, world_size=world_size, auto_pg=False, custom_dist=True)
279279 self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
0 commit comments