Skip to content

Commit 34b6bab

Browse files
author
kip-cxj
committed
fix bug
1 parent d35318f commit 34b6bab

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

tests/checkpoint_engine/test_kimi_checkpoint_engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def test_kimi_checkpoint_engine(
6666
rollout_pool,
6767
"kimi_ckpt_engine",
6868
checkpoint_kwargs,
69-
device=get_device_name(),
7069
check_allclose=check_allclose,
7170
)
7271

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class BroadcastOperation:
183183
184184
Args:
185185
rank (int): The rank of the current process.
186-
group_name (dist.ProcessGroup): The NCCL process group.
186+
ranks_group (int): The process group's value.
187187
bucket (torch.Tensor): The tensor to broadcast.
188188
metadata (list[ParameterMeta]): The metadata of the tensor.
189189
"""
@@ -224,7 +224,6 @@ class KIMICheckpointEngine(CheckpointEngine):
224224
Args:
225225
bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
226226
two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size.
227-
group_name (str): The name of the NCCL process group. Defaults to "default".
228227
rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False.
229228
is_master (bool): Whether the current process is the master process. Defaults to False.
230229
rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16.

0 commit comments

Comments
 (0)