
Commit 8dc26f1

author kip-cxj committed: modify readme
1 parent 2699df4

File tree

2 files changed: +14 −7 lines


verl/checkpoint_engine/README.md

Lines changed: 12 additions & 5 deletions
@@ -20,18 +20,25 @@ Checkpoint Engine is a unified abstraction layer to synchronize weights between va
 |nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)<br>- UCX<br>- UCCL<br>- Mooncake|Medium/High|High: dynamically adjusts ring topology|Off-policy training<br>- Trainer/rollout disaggregated<br>- Elastic rollout<br>- Rollout fault tolerance<br>- Heterogeneous hardware rollout
 |kimi_ckpt_engine|MOONCAKE+NCCL/HCCL|p2p+broadcast|NVIDIA/Ascend|High|Low: rebuild communication group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Save checkpoint each time
 
-PS: kimi_ckpt_engine first offloads all weights to the CPU. Then, using the Mooncake transfer engine, these weights are transmitted via P2P to a specific worker in the rollout, followed by a broadcast to all other rollout workers.
+##### kimi_ckpt_engine details
+
+In the kimi_ckpt_engine workflow, the trainer first offloads the weights to the CPU, and the rollout side creates a sub communication group that includes all rollout devices. Then, using the Mooncake transfer engine, the weights are transmitted via P2P to a specific rollout worker and broadcast from there to all other rollout workers.
+
+<img src="https://github.com/kip-cxj/verl/blob/cxj/doc_imgs/docs/_static/kimi_ckpt_engine.png?raw=true" alt="kimi-ckpt-engine" width="50%">
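The offload/P2P/broadcast flow described above can be sketched with `torch.distributed`. This is a minimal illustration under stated assumptions, not verl's actual API: `sync_weights` and `ROLLOUT_SRC_RANK` are hypothetical names, a process group spanning the rollout workers is assumed to already exist, and the Mooncake P2P leg is only indicated in comments.

```python
import torch
import torch.distributed as dist

# Hypothetical: the one rollout rank that receives weights via Mooncake P2P.
ROLLOUT_SRC_RANK = 0

def sync_weights(named_tensors: dict, group=None) -> dict:
    """Fan weights out from the designated rollout worker to all others.

    Assumes the trainer has already offloaded weights to CPU and sent them
    to ROLLOUT_SRC_RANK via the Mooncake transfer engine (P2P leg, not shown).
    """
    for name, tensor in named_tensors.items():
        # One broadcast per tensor inside the rollout sub communication group;
        # non-source ranks receive the data in place.
        dist.broadcast(tensor, src=ROLLOUT_SRC_RANK, group=group)
    return named_tensors
```

In practice a per-tensor broadcast would be slow; the engine groups tensors into buckets first (see `ckpt_get_named_tensor_buckets` in the Python diff below this README).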
+
+This mode requires the P2P feature of checkpoint_engine. Please ensure you have installed it via `pip install 'checkpoint-engine[p2p]'` and that your version is 0.4.0 or higher.
+
+In addition, installing checkpoint-engine[p2p] also installs the transfer engine. This library has no prebuilt packages for Ascend devices and must be compiled from source. For detailed compilation instructions, see: [transfer-engine: ascend direct](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/design/transfer-engine/ascend_direct_transport.md)
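The 0.4.0 floor can also be checked programmatically before enabling the P2P path. Below is a small sketch using `importlib.metadata`; the helper names are illustrative, not part of verl or checkpoint-engine, and the comparison deliberately parses the dotted version numerically rather than comparing strings.

```python
from importlib.metadata import PackageNotFoundError, version

def version_at_least(installed: str, minimum: str) -> bool:
    # Compare dotted versions numerically: "0.10.1" >= "0.4.0" holds here,
    # although it would fail as a plain string comparison.
    as_tuple = lambda v: tuple(int(part) for part in v.split(".")[:3])
    return as_tuple(installed) >= as_tuple(minimum)

def has_p2p_support(minimum: str = "0.4.0") -> bool:
    """True if checkpoint-engine is installed and new enough for the P2P path."""
    try:
        return version_at_least(version("checkpoint-engine"), minimum)
    except PackageNotFoundError:
        return False
```

Note this simple parser assumes plain `X.Y.Z` versions; pre-release tags would need a real version parser.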
 
 ### Benchmark
 1. benchmark setup
 - model: Qwen/Qwen3-30B-A3B-Base
 - trainer: fsdp world_size=2 (on Ascend 910C, which has 64GB of HBM, we set world_size=4)
 - rollout: num_rollout=30 (only receives weights, without CUDA IPC to vllm/sglang)
 ```bash
-python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_kimi_checkpoint_engine.py
+pytest tests/checkpoint_engine/test_correctness_on_gpu.py
+pytest tests/checkpoint_engine/test_correctness_on_npu.py
+pytest tests/checkpoint_engine/test_special_server_adapter.py
 ```
 
 2. benchmark result

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from checkpoint_engine.ps import H2DBucket, ParameterMeta, ParameterServer, _gen_h2d_buckets, _to_named_tensor
 
 from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry
-from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
+from verl.utils.device import get_nccl_backend, get_torch_device
 from verl.utils.net_utils import get_free_port
 
 logger = logging.getLogger(__name__)
@@ -331,7 +331,7 @@ def offload_cpu(name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
         start_time = time.time()
         named_tensors = {}
         for named_tensors_gpu in ckpt_get_named_tensor_buckets(
-            weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
+            weights, self.bucket_size, self.trainer_world_size, self.rank, self.rollout_dtype
         ):
             with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
                 futures = [
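The `ckpt_get_named_tensor_buckets` call in the hunk above groups weights into size-bounded buckets before they are offloaded and transferred. The greedy grouping pattern can be sketched as follows; this is a simplified, hypothetical stand-in (the real helper also partitions by world size, rank, and target dtype):

```python
import torch

def bucket_named_tensors(named: dict, bucket_size: int):
    """Yield lists of (name, tensor) whose combined byte size stays within bucket_size."""
    bucket, used = [], 0
    for name, tensor in named.items():
        nbytes = tensor.numel() * tensor.element_size()
        # Greedy: start a new bucket once the next tensor would overflow this one.
        if bucket and used + nbytes > bucket_size:
            yield bucket
            bucket, used = [], 0
        bucket.append((name, tensor))
        used += nbytes
    if bucket:
        yield bucket
```

A tensor larger than `bucket_size` still gets a bucket of its own rather than being split; bucketing only amortizes per-transfer overhead, it does not shard individual tensors.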
