File tree Expand file tree Collapse file tree 2 files changed +1
-3
lines changed
Expand file tree Collapse file tree 2 files changed +1
-3
lines changed Original file line number Diff line number Diff line change @@ -66,7 +66,6 @@ def test_kimi_checkpoint_engine(
6666 rollout_pool ,
6767 "kimi_ckpt_engine" ,
6868 checkpoint_kwargs ,
69- device = get_device_name (),
7069 check_allclose = check_allclose ,
7170 )
7271
Original file line number Diff line number Diff line change @@ -183,7 +183,7 @@ class BroadcastOperation:
183183
184184 Args:
185185 rank (int): The rank of the current process.
186- group_name (dist.ProcessGroup ): The NCCL process group.
186+ ranks_group (int ): The process group's value .
187187 bucket (torch.Tensor): The tensor to broadcast.
188188 metadata (list[ParameterMeta]): The metadata of the tensor.
189189 """
@@ -224,7 +224,6 @@ class KIMICheckpointEngine(CheckpointEngine):
224224 Args:
225225 bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
226226 two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size.
227- group_name (str): The name of the NCCL process group. Defaults to "default".
228227 rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False.
229228 is_master (bool): Whether the current process is the master process. Defaults to False.
230229 rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16.
You can’t perform that action at this time.
0 commit comments