Skip to content

Commit 3f62182

Browse files
author
yexin
committed
[ckpt] feat: add mooncake backend
1 parent 7b24498 commit 3f62182

File tree

3 files changed

+287
-0
lines changed

3 files changed

+287
-0
lines changed

tests/checkpoint_engine/test_correctness_on_npu.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,54 @@ async def test_hccl_checkpoint_engine(
7474
ray.shutdown()
7575

7676

77+
@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
78+
@pytest.mark.asyncio
79+
@pytest.mark.parametrize("device", ["npu"])
80+
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
81+
async def test_mooncake_checkpoint_engine(
82+
rebuild_group,
83+
num_trainer,
84+
num_rollout,
85+
device,
86+
num_nodes=1,
87+
num_gpus_per_node=8,
88+
check_allclose=True,
89+
model_path="~/models/Qwen/Qwen3-8B-Base",
90+
):
91+
model_path = os.path.expanduser(model_path)
92+
ray.init(
93+
runtime_env={
94+
"env_vars": {
95+
"ASCEND_USE_SHORT_CONNECTION": "1",
96+
"VERL_LOGGING_LEVEL": "DEBUG",
97+
}
98+
}
99+
)
100+
101+
# initialize config
102+
checkpoint_engine_config = CheckpointEngineConfig(
103+
backend="mooncake", engine_kwargs={"mooncake": {"device": device, "rebuild_group": rebuild_group}}
104+
)
105+
model_config = HFModelConfig(path=model_path, use_remove_padding=True)
106+
rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)
107+
108+
# create trainer and rollout worker group
109+
resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
110+
resource_pool.get_placement_groups(device_name=get_device_name())
111+
trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
112+
trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
113+
trainer.reset()
114+
rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)
115+
116+
# create checkpoint engine manager
117+
checkpoint_manager = CheckpointEngineManager(backend="mooncake", trainer=trainer, replicas=replicas)
118+
for _ in range(3):
119+
await checkpoint_manager.update_weights()
120+
rollout.check_weights()
121+
122+
ray.shutdown()
123+
124+
77125
if __name__ == "__main__":
78126
test_hccl_checkpoint_engine(
79127
rebuild_group=False,

verl/checkpoint_engine/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,10 @@
5151
__all__ += ["NIXLCheckpointEngine"]
5252
except ImportError:
5353
NIXLCheckpointEngine = None
54+
55+
try:
56+
from .mooncake_checkpoint_engine import MooncakeCheckpointEngine
57+
58+
__all__ += ["MoonCakeCheckpointEngine"]
59+
except ImportError:
60+
MooncakeCheckpointEngine = None
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import asyncio
15+
import logging
16+
import os
17+
import time
18+
import gc
19+
from collections import defaultdict
20+
from dataclasses import dataclass
21+
from typing import Any, AsyncGenerator, Generator
22+
23+
import ray
24+
import torch
25+
from vllm.distributed.utils import StatelessProcessGroup
26+
from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta
27+
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address
28+
from verl.utils.device import get_torch_device
29+
30+
from mooncake.engine import TransferEngine
31+
32+
logger = logging.getLogger(__name__)
33+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
34+
35+
36+
@CheckpointEngineRegistry.register("mooncake")
37+
class MooncakeCheckpointEngine(CheckpointEngine):
38+
"""Mooncake checkpoint engine with p2p communication using TransferEngine
39+
40+
Args:
41+
bucket_size (int): Bucket size in bytes to transfer multiple weights at one time.
42+
device (str): The device to use for the checkpoint engine, "cpu" or "cuda".
43+
rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers.
44+
device_name (str): Mooncake device name filter.
45+
"""
46+
47+
def __init__(
48+
self,
49+
bucket_size: int,
50+
device: str = "cuda",
51+
rollout_dtype: torch.dtype = torch.bfloat16,
52+
device_name: str = "",
53+
is_master: bool = False,
54+
rebuild_group: bool = False,
55+
):
56+
self.bucket_size = bucket_size
57+
self.device = device
58+
self.rollout_dtype = rollout_dtype
59+
self.is_master = is_master
60+
self.rebuild_group = rebuild_group
61+
62+
rank = int(os.environ["RANK"])
63+
device_count = get_torch_device().device_count()
64+
local_rank = rank % device_count
65+
get_torch_device().set_device(local_rank)
66+
67+
self.engine = TransferEngine()
68+
hostname = ray.util.get_node_ip_address().strip("[]")
69+
ret = self.engine.initialize(
70+
hostname,
71+
"P2PHANDSHAKE",
72+
"ascend_direct" if self.device == "npu" else "rdma",
73+
device_name,
74+
)
75+
assert ret == 0, f"TransferEngine initialize failed ret={ret}"
76+
77+
rpc_port = self.engine.get_rpc_port()
78+
self.session_id = f"{hostname}:{rpc_port}"
79+
self.hostname = hostname
80+
81+
self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device)
82+
assert self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) == 0, "register_memory failed"
83+
84+
def prepare(self) -> dict[str, Any]:
85+
"""Prepare send and recv buckets"""
86+
# self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device)
87+
# self.engine.register_memory(self.buf.data_ptr(), self.bucket_size)
88+
port, _ = get_free_port(self.hostname)
89+
return {"addr": self.hostname, "port": port}
90+
91+
@classmethod
92+
def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadatas: list[dict]):
93+
trainer_kwargs = {
94+
"rank": [0] + [-1] * (trainer_world_size - 1),
95+
"world_size": [rollout_world_size + 1] * trainer_world_size,
96+
"metadata": [metadatas[0]] * trainer_world_size,
97+
}
98+
rollout_kwargs = {
99+
"rank": list(range(1, rollout_world_size + 1)),
100+
"world_size": [rollout_world_size + 1] * rollout_world_size,
101+
"metadata": [metadatas[0]] * rollout_world_size,
102+
}
103+
return trainer_kwargs, rollout_kwargs
104+
105+
def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any]):
106+
self.rank = rank
107+
self.world_size = world_size
108+
if rank < 0:
109+
return
110+
111+
self.store = StatelessProcessGroup.create(
112+
host=metadata["addr"],
113+
port=metadata["port"],
114+
rank=rank,
115+
world_size=world_size,
116+
)
117+
118+
if self.is_master:
119+
buffer_info = {
120+
"session_id": self.session_id,
121+
"ptr": self.buf.data_ptr(),
122+
"len": self.bucket_size,
123+
}
124+
self.store.broadcast_obj(obj=buffer_info, src=0)
125+
else:
126+
self.buffer_info = self.store.broadcast_obj(obj=None, src=0)
127+
128+
129+
def finalize(self):
130+
"""Cleanup communication and deregister memory"""
131+
self.store = None
132+
get_torch_device().empty_cache()
133+
gc.collect()
134+
135+
async def wait_for_complete(self):
136+
magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
137+
target = magic.repeat(self.world_size - 1)
138+
while True:
139+
if torch.equal(self.buf[4:4 * self.world_size], target):
140+
break
141+
await asyncio.sleep(0)
142+
143+
@torch.no_grad()
144+
async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
145+
"""Send weights using Mooncake TransferEngine"""
146+
start_time = time.time()
147+
bucket_meta: dict[str, TensorMeta] = {}
148+
offset = 0
149+
150+
for name, weight in weights:
151+
if self.rank != 0:
152+
continue
153+
weight = weight.to(self.rollout_dtype)
154+
155+
if offset + weight.nbytes > self.bucket_size:
156+
get_torch_device().synchronize
157+
info = {
158+
"bucket_meta": bucket_meta,
159+
"len": offset,
160+
"is_last": False,
161+
}
162+
self.store.broadcast_obj(obj=info, src=0)
163+
await self.wait_for_complete()
164+
bucket_meta = {}
165+
offset = 0
166+
167+
assert offset + weight.nbytes <= self.bucket_size, (
168+
f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
169+
)
170+
171+
bucket_meta[name] = {
172+
"name": name,
173+
"shape": weight.shape,
174+
"dtype": weight.dtype,
175+
"offset": offset,
176+
}
177+
self.buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
178+
offset += weight.nbytes
179+
180+
if self.rank != 0:
181+
return
182+
183+
get_torch_device().synchronize()
184+
info = {
185+
"bucket_meta": bucket_meta,
186+
"len": offset,
187+
"is_last": True,
188+
}
189+
self.store.broadcast_obj(obj=info, src=0)
190+
await self.wait_for_complete()
191+
logger.info(f"send weights done, time cost: {time.time() - start_time:.2f}s")
192+
193+
@torch.no_grad()
194+
async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
195+
"""Receive weights using Mooncake TransferEngine"""
196+
start_time = time.time()
197+
total_bytes = 0
198+
while True:
199+
info = self.store.broadcast_obj(obj=None, src=0)
200+
ret = self.engine.transfer_sync_read(
201+
self.buffer_info["session_id"],
202+
self.buf.data_ptr(),
203+
self.buffer_info["ptr"],
204+
info["len"],
205+
)
206+
assert ret == 0, f"transfer_sync_read failed {ret}"
207+
total_bytes += info["len"]
208+
for name, meta in info["bucket_meta"].items():
209+
dtype, shape = meta["dtype"], meta["shape"]
210+
size = dtype.itemsize * shape.numel()
211+
tensor = self.buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape)
212+
yield name, tensor
213+
214+
self.buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
215+
216+
offset = self.buffer_info["ptr"] + self.rank * 4
217+
ret = self.engine.transfer_sync_write(
218+
self.buffer_info["session_id"],
219+
self.buf.data_ptr(),
220+
offset,
221+
4,
222+
)
223+
assert ret == 0, f"transfer_sync_write failed {ret}"
224+
if info["is_last"]:
225+
break
226+
227+
time_cost = time.time() - start_time
228+
bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
229+
logger.info(
230+
f"Rank {self.rank} receive weights done, "
231+
f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s"
232+
)

0 commit comments

Comments
 (0)