Skip to content

Commit 5f985f9

Browse files
author
kip-cxj
committed
fix bug
1 parent d35318f commit 5f985f9

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

tests/checkpoint_engine/test_kimi_checkpoint_engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def test_kimi_checkpoint_engine(
6666
rollout_pool,
6767
"kimi_ckpt_engine",
6868
checkpoint_kwargs,
69-
device=get_device_name(),
7069
check_allclose=check_allclose,
7170
)
7271

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class BroadcastOperation:
183183
184184
Args:
185185
rank (int): The rank of the current process.
186-
group_name (dist.ProcessGroup): The NCCL process group.
186+
ranks_group (int): The integer value identifying the process group.
187187
bucket (torch.Tensor): The tensor to broadcast.
188188
metadata (list[ParameterMeta]): The metadata of the tensor.
189189
"""
@@ -224,7 +224,6 @@ class KIMICheckpointEngine(CheckpointEngine):
224224
Args:
225225
bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
226226
two buffers to send and recv weights at the same time, so the device memory overhead is 2 * bucket_size.
227-
group_name (str): The name of the NCCL process group. Defaults to "default".
228227
rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False.
229228
is_master (bool): Whether the current process is the master process. Defaults to False.
230229
rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16.
@@ -273,7 +272,8 @@ def init_process_group(self, rank: int, world_size: int, master_metadata: Master
273272
world_size (int): The total number of processes.
274273
"""
275274
self.rank = rank
276-
# unregister_memory in transfer engine is not supported on NPU, so we have to initialize ParameterServer each time
275+
# unregister_memory in transfer engine is not supported on NPU,
276+
# so we have to initialize ParameterServer each time
277277
if get_device_name() == "npu" or not self.initialized:
278278
self.parameter_server = ParameterServer(rank=rank, world_size=world_size, auto_pg=False, custom_dist=True)
279279
self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)

0 commit comments

Comments (0)