|
| 1 | +# Copyright 2024 Bytedance Ltd. and/or its affiliates |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import asyncio |
| 15 | +import logging |
| 16 | +import os |
| 17 | +import time |
| 18 | +import gc |
| 19 | +from collections import defaultdict |
| 20 | +from dataclasses import dataclass |
| 21 | +from typing import Any, AsyncGenerator, Generator |
| 22 | + |
| 23 | +import ray |
| 24 | +import torch |
| 25 | +from vllm.distributed.utils import StatelessProcessGroup |
| 26 | +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta |
| 27 | +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address |
| 28 | +from verl.utils.device import get_torch_device |
| 29 | + |
| 30 | +from mooncake.engine import TransferEngine |
| 31 | + |
| 32 | +logger = logging.getLogger(__name__) |
| 33 | +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) |
| 34 | + |
| 35 | + |
| 36 | +@CheckpointEngineRegistry.register("mooncake") |
| 37 | +class MooncakeCheckpointEngine(CheckpointEngine): |
| 38 | + """Mooncake checkpoint engine with p2p communication using TransferEngine |
| 39 | +
|
| 40 | + Args: |
| 41 | + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. |
| 42 | + device (str): The device to use for the checkpoint engine, "cpu" or "cuda". |
| 43 | + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. |
| 44 | + device_name (str): Mooncake device name filter. |
| 45 | + """ |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + bucket_size: int, |
| 50 | + device: str = "cuda", |
| 51 | + rollout_dtype: torch.dtype = torch.bfloat16, |
| 52 | + device_name: str = "", |
| 53 | + is_master: bool = False, |
| 54 | + rebuild_group: bool = False, |
| 55 | + ): |
| 56 | + self.bucket_size = bucket_size |
| 57 | + self.device = device |
| 58 | + self.rollout_dtype = rollout_dtype |
| 59 | + self.is_master = is_master |
| 60 | + self.rebuild_group = rebuild_group |
| 61 | + |
| 62 | + rank = int(os.environ["RANK"]) |
| 63 | + device_count = get_torch_device().device_count() |
| 64 | + local_rank = rank % device_count |
| 65 | + get_torch_device().set_device(local_rank) |
| 66 | + |
| 67 | + self.engine = TransferEngine() |
| 68 | + hostname = ray.util.get_node_ip_address().strip("[]") |
| 69 | + ret = self.engine.initialize( |
| 70 | + hostname, |
| 71 | + "P2PHANDSHAKE", |
| 72 | + "ascend_direct" if self.device == "npu" else "rdma", |
| 73 | + device_name, |
| 74 | + ) |
| 75 | + assert ret == 0, f"TransferEngine initialize failed ret={ret}" |
| 76 | + |
| 77 | + rpc_port = self.engine.get_rpc_port() |
| 78 | + self.session_id = f"{hostname}:{rpc_port}" |
| 79 | + self.hostname = hostname |
| 80 | + |
| 81 | + self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device) |
| 82 | + assert self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) == 0, "register_memory failed" |
| 83 | + |
| 84 | + def prepare(self) -> dict[str, Any]: |
| 85 | + """Prepare send and recv buckets""" |
| 86 | + # self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device) |
| 87 | + # self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) |
| 88 | + port, _ = get_free_port(self.hostname) |
| 89 | + return {"addr": self.hostname, "port": port} |
| 90 | + |
| 91 | + @classmethod |
| 92 | + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadatas: list[dict]): |
| 93 | + trainer_kwargs = { |
| 94 | + "rank": [0] + [-1] * (trainer_world_size - 1), |
| 95 | + "world_size": [rollout_world_size + 1] * trainer_world_size, |
| 96 | + "metadata": [metadatas[0]] * trainer_world_size, |
| 97 | + } |
| 98 | + rollout_kwargs = { |
| 99 | + "rank": list(range(1, rollout_world_size + 1)), |
| 100 | + "world_size": [rollout_world_size + 1] * rollout_world_size, |
| 101 | + "metadata": [metadatas[0]] * rollout_world_size, |
| 102 | + } |
| 103 | + return trainer_kwargs, rollout_kwargs |
| 104 | + |
| 105 | + def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any]): |
| 106 | + self.rank = rank |
| 107 | + self.world_size = world_size |
| 108 | + if rank < 0: |
| 109 | + return |
| 110 | + |
| 111 | + self.store = StatelessProcessGroup.create( |
| 112 | + host=metadata["addr"], |
| 113 | + port=metadata["port"], |
| 114 | + rank=rank, |
| 115 | + world_size=world_size, |
| 116 | + ) |
| 117 | + |
| 118 | + if self.is_master: |
| 119 | + buffer_info = { |
| 120 | + "session_id": self.session_id, |
| 121 | + "ptr": self.buf.data_ptr(), |
| 122 | + "len": self.bucket_size, |
| 123 | + } |
| 124 | + self.store.broadcast_obj(obj=buffer_info, src=0) |
| 125 | + else: |
| 126 | + self.buffer_info = self.store.broadcast_obj(obj=None, src=0) |
| 127 | + |
| 128 | + |
| 129 | + def finalize(self): |
| 130 | + """Cleanup communication and deregister memory""" |
| 131 | + self.store = None |
| 132 | + get_torch_device().empty_cache() |
| 133 | + gc.collect() |
| 134 | + |
| 135 | + async def wait_for_complete(self): |
| 136 | + magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) |
| 137 | + target = magic.repeat(self.world_size - 1) |
| 138 | + while True: |
| 139 | + if torch.equal(self.buf[4:4 * self.world_size], target): |
| 140 | + break |
| 141 | + await asyncio.sleep(0) |
| 142 | + |
| 143 | + @torch.no_grad() |
| 144 | + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): |
| 145 | + """Send weights using Mooncake TransferEngine""" |
| 146 | + start_time = time.time() |
| 147 | + bucket_meta: dict[str, TensorMeta] = {} |
| 148 | + offset = 0 |
| 149 | + |
| 150 | + for name, weight in weights: |
| 151 | + if self.rank != 0: |
| 152 | + continue |
| 153 | + weight = weight.to(self.rollout_dtype) |
| 154 | + |
| 155 | + if offset + weight.nbytes > self.bucket_size: |
| 156 | + get_torch_device().synchronize |
| 157 | + info = { |
| 158 | + "bucket_meta": bucket_meta, |
| 159 | + "len": offset, |
| 160 | + "is_last": False, |
| 161 | + } |
| 162 | + self.store.broadcast_obj(obj=info, src=0) |
| 163 | + await self.wait_for_complete() |
| 164 | + bucket_meta = {} |
| 165 | + offset = 0 |
| 166 | + |
| 167 | + assert offset + weight.nbytes <= self.bucket_size, ( |
| 168 | + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." |
| 169 | + ) |
| 170 | + |
| 171 | + bucket_meta[name] = { |
| 172 | + "name": name, |
| 173 | + "shape": weight.shape, |
| 174 | + "dtype": weight.dtype, |
| 175 | + "offset": offset, |
| 176 | + } |
| 177 | + self.buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) |
| 178 | + offset += weight.nbytes |
| 179 | + |
| 180 | + if self.rank != 0: |
| 181 | + return |
| 182 | + |
| 183 | + get_torch_device().synchronize() |
| 184 | + info = { |
| 185 | + "bucket_meta": bucket_meta, |
| 186 | + "len": offset, |
| 187 | + "is_last": True, |
| 188 | + } |
| 189 | + self.store.broadcast_obj(obj=info, src=0) |
| 190 | + await self.wait_for_complete() |
| 191 | + logger.info(f"send weights done, time cost: {time.time() - start_time:.2f}s") |
| 192 | + |
| 193 | + @torch.no_grad() |
| 194 | + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: |
| 195 | + """Receive weights using Mooncake TransferEngine""" |
| 196 | + start_time = time.time() |
| 197 | + total_bytes = 0 |
| 198 | + while True: |
| 199 | + info = self.store.broadcast_obj(obj=None, src=0) |
| 200 | + ret = self.engine.transfer_sync_read( |
| 201 | + self.buffer_info["session_id"], |
| 202 | + self.buf.data_ptr(), |
| 203 | + self.buffer_info["ptr"], |
| 204 | + info["len"], |
| 205 | + ) |
| 206 | + assert ret == 0, f"transfer_sync_read failed {ret}" |
| 207 | + total_bytes += info["len"] |
| 208 | + for name, meta in info["bucket_meta"].items(): |
| 209 | + dtype, shape = meta["dtype"], meta["shape"] |
| 210 | + size = dtype.itemsize * shape.numel() |
| 211 | + tensor = self.buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) |
| 212 | + yield name, tensor |
| 213 | + |
| 214 | + self.buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) |
| 215 | + |
| 216 | + offset = self.buffer_info["ptr"] + self.rank * 4 |
| 217 | + ret = self.engine.transfer_sync_write( |
| 218 | + self.buffer_info["session_id"], |
| 219 | + self.buf.data_ptr(), |
| 220 | + offset, |
| 221 | + 4, |
| 222 | + ) |
| 223 | + assert ret == 0, f"transfer_sync_write failed {ret}" |
| 224 | + if info["is_last"]: |
| 225 | + break |
| 226 | + |
| 227 | + time_cost = time.time() - start_time |
| 228 | + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) |
| 229 | + logger.info( |
| 230 | + f"Rank {self.rank} receive weights done, " |
| 231 | + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" |
| 232 | + ) |
0 commit comments