Commit cd6d4c1

Author: kip-cxj
Commit message: modify readme
1 parent 2699df4

File tree: 2 files changed (+9, -3 lines)


verl/checkpoint_engine/README.md

Lines changed: 7 additions & 1 deletion
@@ -20,7 +20,13 @@ Checkpoint Engine is an unified abstract layer to synchronize weights between va
 |nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)<br>- UCX<br>- UCCL<br>- Mooncake|Medium/High|High: dynamic adjust ring topology|Off-policy training<br>- Trainer/rollout disaggregated<br>- Elastic rollout<br>- Rollout fault tolerance<br>- Heterogeneous hardware rollout
 |kimi_ckpt_engine|MOONCAKE+NCCL/HCCL|p2p+broadcast|NVIDIA/Ascend|High|Low: rebuild communication group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Save checkpoint each time

-PS: kimi_ckpt_engine first offloads all weights to the CPU. Then, using the Mooncake transfer engine, these weights are transmitted via P2P to a specific worker in the rollout, followed by a broadcast to all other rollout workers.
+##### kimi_ckpt_engine detail
+
+In the kimi_ckpt_engine workflow, the trainer first offloads the weights to the CPU, and the rollout creates a sub communication group that includes all rollout devices. Then, using the Mooncake transfer engine, the weights are transmitted via P2P to a specific rollout worker, followed by a broadcast to all other rollout workers.
+
+<img src="https://github.com/kip-cxj/verl/blob/cxj/doc_imgs/docs/_static/kimi_ckpt_engine.png?raw=true" alt="kimi-ckpt-engine" width="50%">
+
+This mode requires the P2P feature of checkpoint_engine. Please ensure you have installed it via `pip install 'checkpoint-engine[p2p]'` and that your version is 0.4.0 or higher.

 ### Benchmark
 1. benchmark setup
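The offload-then-broadcast flow described in the README hunk above (CPU offload, Mooncake P2P to one rollout worker, then a collective broadcast to the rest of the sub communication group) can be sketched as below. This is an illustrative assumption, not the verl/checkpoint-engine API: the function name `update_rollout_weights`, the `src_rank` convention, and the placeholder-tensor protocol are all hypothetical.

```python
# Illustrative sketch (NOT the verl implementation) of the kimi_ckpt_engine
# update pattern: one rollout rank receives the weights via P2P, then
# re-shares them with every other rollout rank via a collective broadcast.
import torch
import torch.distributed as dist


def update_rollout_weights(weights: dict[str, torch.Tensor], group, src_rank: int = 0):
    """Broadcast trainer weights to every rollout worker in `group`.

    Only `src_rank` is assumed to hold the real values (received via P2P);
    every other rank passes placeholder tensors of matching shape/dtype,
    which the broadcast overwrites in place.
    """
    for name, tensor in weights.items():
        # Collective: after this call, all ranks in the sub communication
        # group hold the tensor that the designated worker fetched via P2P.
        dist.broadcast(tensor, src=src_rank, group=group)
    return weights
```

In a real setup the broadcast backend would be NCCL/HCCL on the rollout devices; `gloo` on CPU works the same way for illustration.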

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from checkpoint_engine.ps import H2DBucket, ParameterMeta, ParameterServer, _gen_h2d_buckets, _to_named_tensor

 from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry
-from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
+from verl.utils.device import get_nccl_backend, get_torch_device
 from verl.utils.net_utils import get_free_port

 logger = logging.getLogger(__name__)
@@ -331,7 +331,7 @@ def offload_cpu(name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
         start_time = time.time()
         named_tensors = {}
         for named_tensors_gpu in ckpt_get_named_tensor_buckets(
-            weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
+            weights, self.bucket_size, self.trainer_world_size, self.rank, self.rollout_dtype
         ):
             with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
                 futures = [
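The second hunk above fixes the world-size argument in a bucketed CPU-offload loop: weights are grouped into size-capped buckets and copied to host memory with a thread pool. A minimal self-contained sketch of that pattern follows; the helper names `iter_tensor_buckets` and `offload_to_cpu` are hypothetical stand-ins, not verl's `ckpt_get_named_tensor_buckets` (which additionally shards by world size, rank, and dtype).

```python
# Rough sketch (assumed helper names, not verl's API) of bucketed
# CPU offload: cap each bucket's byte size, then copy its tensors to
# host memory concurrently with a thread pool.
import concurrent.futures
import torch


def iter_tensor_buckets(named_weights: dict[str, torch.Tensor], bucket_size_bytes: int):
    """Yield lists of (name, tensor) whose total size stays under the cap."""
    bucket, bucket_bytes = [], 0
    for name, tensor in named_weights.items():
        nbytes = tensor.element_size() * tensor.numel()
        if bucket and bucket_bytes + nbytes > bucket_size_bytes:
            yield bucket
            bucket, bucket_bytes = [], 0
        bucket.append((name, tensor))
        bucket_bytes += nbytes
    if bucket:
        yield bucket


def offload_to_cpu(named_weights: dict[str, torch.Tensor], bucket_size_bytes: int = 1 << 30):
    """Copy every tensor to CPU, one bucket at a time, 32 copies in flight."""
    named_tensors = {}
    for bucket in iter_tensor_buckets(named_weights, bucket_size_bytes):
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            futures = [
                executor.submit(lambda n, t: (n, t.detach().to("cpu")), name, tensor)
                for name, tensor in bucket
            ]
            for fut in concurrent.futures.as_completed(futures):
                name, cpu_tensor = fut.result()
                named_tensors[name] = cpu_tensor
    return named_tensors
```

Bucketing bounds peak host-memory pressure per iteration, while the thread pool overlaps the individual device-to-host copies within a bucket.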
