diff --git a/checkpoint_engine/__init__.py b/checkpoint_engine/__init__.py index 3c24562..28bc270 100644 --- a/checkpoint_engine/__init__.py +++ b/checkpoint_engine/__init__.py @@ -2,3 +2,39 @@ from ._version import __version__ except ImportError: __version__ = "dev" + +from .api import request_inference_to_update +from .data_types import ( + BucketRange, + DataToGather, + H2DBucket, + MemoryBuffer, + MemoryBufferMetaList, + MemoryBufferMetas, + ParameterMeta, +) +from .device_utils import DeviceManager, get_ip, npu_generate_uuid +from .p2p_store import P2PStore +from .ps import ParameterServer +from .worker import FlattenedTensorMetadata, VllmColocateWorkerExtension, update_weights_from_ipc + + +__all__ = [ + "BucketRange", + "DataToGather", + "DeviceManager", + "FlattenedTensorMetadata", + "H2DBucket", + "MemoryBuffer", + "MemoryBufferMetaList", + "MemoryBufferMetas", + "P2PStore", + "ParameterMeta", + "ParameterServer", + "VllmColocateWorkerExtension", + "__version__", + "get_ip", + "npu_generate_uuid", + "request_inference_to_update", + "update_weights_from_ipc", +] diff --git a/checkpoint_engine/__main__.py b/checkpoint_engine/__main__.py new file mode 100644 index 0000000..3f155dc --- /dev/null +++ b/checkpoint_engine/__main__.py @@ -0,0 +1,28 @@ +import argparse +import os + +from loguru import logger + +from checkpoint_engine.api import _init_api +from checkpoint_engine.ps import ParameterServer + + +@logger.catch(reraise=True) +def run_from_cli(): + import uvicorn + + parser = argparse.ArgumentParser(description="Parameter Server") + parser.add_argument("--uds", type=str) + + args = parser.parse_args() + logger.info( + f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}" + ) + + assert args.uds and len(args.uds) > 0, args.uds + ps = ParameterServer(auto_pg=True) + uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60) + + +if __name__ == "__main__": + run_from_cli() diff --git a/checkpoint_engine/api.py b/checkpoint_engine/api.py new file mode 100644 index 0000000..e61b41d --- /dev/null +++ b/checkpoint_engine/api.py @@ -0,0 +1,95 @@ +from collections.abc import Callable +from typing import Any + +import fastapi +import httpx +from fastapi import Request +from fastapi.responses import JSONResponse, Response +from loguru import logger +from pydantic import BaseModel + +from checkpoint_engine.ps import ParameterServer + + +def request_inference_to_update( + url: str, + socket_paths: dict[str, str], + timeout: float = 300.0, + uds: str | None = None, +): + """Send an inference update request to inference server via HTTP or Unix socket. + + Args: + url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to. + socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights. + timeout (float, optional): Request timeout in seconds. Defaults to 300.0. + uds (str, optional): Path to a Unix domain socket. If provided, the request + will be sent via the Unix socket instead of HTTP. Defaults to None. + + Raises: + httpx.HTTPStatusError: If the response contains an HTTP error status. + httpx.RequestError: If there was an issue while making the request. + """ + resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post( + url, + json={ + "method": "update_weights_from_ipc", + "args": [socket_paths], + "timeout": timeout, + }, + timeout=timeout, + ) + resp.raise_for_status() + + +def _init_api(ps: ParameterServer) -> Any: + app = fastapi.FastAPI() + + class RegisterRequest(BaseModel): + files: list[str] + + class UpdateRequest(BaseModel): + ranks: list[int] = [] + update_url: str | None = None + inference_group_ranks: list[int] = [] + timeout: float = 300.0 + uds: str | None = None + + def wrap_exception(func: Callable[[], None]) -> Response: + try: + func() + except Exception as e: # noqa: BLE001 + logger.exception(f"wrap exception {func} failed") + return JSONResponse(content=str(e), status_code=500) + return Response(status_code=200) + + @app.post("/v1/checkpoints/{checkpoint_name}/files") + async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response: + return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files)) + + @app.delete("/v1/checkpoints/{checkpoint_name}") + async def unregister_checkpoint(checkpoint_name: str) -> Response: + return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name)) + + @app.get("/v1/healthz") + async def healthz() -> Response: + return Response(status_code=200) + + @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas") + async def gather_metas(checkpoint_name: str) -> Response: + return wrap_exception(lambda: ps.gather_metas(checkpoint_name)) + + @app.post("/v1/checkpoints/{checkpoint_name}/update") + async def update(checkpoint_name: str, req: UpdateRequest) -> Response: + def update_func(socket_paths: list[tuple[str, str]]): + if req.update_url is None: + return + if req.inference_group_ranks: + socket_paths = [socket_paths[i] for i in req.inference_group_ranks] + request_inference_to_update( + req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds + ) + + return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks)) + + return app diff --git a/checkpoint_engine/data_types.py b/checkpoint_engine/data_types.py new file mode 100644 index 0000000..7cb90db --- /dev/null +++ b/checkpoint_engine/data_types.py @@ -0,0 +1,111 @@ +from typing import TYPE_CHECKING, Annotated, Any, NamedTuple + +import torch +from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema + + +if TYPE_CHECKING: + from typing import TypeVar + + from typing_extensions import TypedDict + + class FileMeta(TypedDict): + key: str # parameter name + dtype: torch.dtype + shape: torch.Size + type: type + tp_concat_dim: int + + T = TypeVar("T") + + +def _dt_validate(value: Any) -> torch.dtype: + if isinstance(value, str): + if not value.startswith("torch."): + raise ValueError(f"dtype {value} should start with torch.") + try: + value = getattr(torch, value.split(".")[1]) + except AttributeError as e: + raise ValueError(f"unknown dtype: {value}") from e + if not isinstance(value, torch.dtype): + raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}") + return value + + +_TorchDtype = Annotated[ + torch.dtype, + PlainValidator(_dt_validate), + PlainSerializer(lambda x: str(x), return_type=str), + WithJsonSchema({"type": "string"}, mode="serialization"), +] + + +def _size_validate(value: Any) -> torch.Size: + if isinstance(value, list | tuple): + return torch.Size(value) + if not isinstance(value, torch.Size): + raise TypeError(f"size {value} should be torch.Size, got {type(value)}") + return value + + +_TorchSize = Annotated[ + torch.Size, + PlainValidator(_size_validate), + PlainSerializer(lambda x: tuple(x), return_type=tuple), + WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"), +] + + +def _tensor_validate(value: Any) -> torch.Tensor: + if isinstance(value, torch.Tensor): + return value + raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}") + + +_TorchTensor = Annotated[ + torch.Tensor, + PlainValidator(_tensor_validate), +] + + +class ParameterMeta(BaseModel): + name: str + dtype: _TorchDtype + shape: _TorchSize + aligned_size: int + + +class BucketRange(NamedTuple): + idx: int # bucket_idx of MemoryBucket in memory_pool + offset: int + size: int + + +class H2DBucket(BaseModel): + size: int + ranges: list[BucketRange] + items: list[ParameterMeta] + + +class MemoryBufferMetas(BaseModel): + metas: list[ParameterMeta] + ptr: int + size: int + + +class MemoryBuffer(BaseModel): + buffer: _TorchTensor + size: int + metas: list[ParameterMeta] + manually_pinned: bool = False + + +class MemoryBufferMetaList(BaseModel): + p2p_store_addr: str | None + memory_buffer_metas_list: list[MemoryBufferMetas] + rdma_device: str + + +class DataToGather(MemoryBufferMetaList): + host_ip: str + device_uuid: str diff --git a/checkpoint_engine/p2p_store.py b/checkpoint_engine/p2p_store.py new file mode 100644 index 0000000..b4be02a --- /dev/null +++ b/checkpoint_engine/p2p_store.py @@ -0,0 +1,210 @@ +import ctypes +import os +import random +import time + +import torch +from loguru import logger + +from checkpoint_engine.device_utils import DeviceManager, get_ip + + +def _ibv_get_device_list() -> list[str]: + lib = ctypes.CDLL("libibverbs.so.1") + lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices + lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device ** + + lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)] + lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device * + lib.ibv_get_device_name.restype = ctypes.c_char_p # const char * + + num = ctypes.c_int() + dev_array = lib.ibv_get_device_list(ctypes.byref(num)) + if not dev_array or num.value <= 0: + return [] + + devices = [] + for i in range(num.value): + dev_ptr = dev_array[i] # struct ibv_device * + name = lib.ibv_get_device_name(dev_ptr) # const char * + devices.append(name.decode()) + lib.ibv_free_device_list(dev_array) + return devices + + +def _get_rdma_devices() -> list[str]: + """ + use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return + """ + devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES") + if devices_str: + return devices_str.split(",") + # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices + hca = os.getenv("NCCL_IB_HCA", None) + return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list() + + +def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: + """ + implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc. + """ + if not devices: + raise RuntimeError("no rdma devices found") + try: + assert len(devices) <= gpu_count, ( + f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" + ) + assert gpu_count % len(devices) == 0, ( + f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" + ) + return devices[local_rank // (gpu_count // len(devices))] + except AssertionError: + logger.error( + "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices." + "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices." + "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'." + ) + raise + + +def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: + """ + The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8. + The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662. + + The list is comma-separated; port numbers are NOT supported yet. + An optional prefix '^' indicates the list is an exclude list. + A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix. + Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported. + + Examples: + - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`. + - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`. + - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`. + - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`. + """ + max_hcas = 32 + if not value or value.strip() == "": + return available_devices[:max_hcas] + + value = value.strip() + result = [] + is_exclude = value.startswith("^") + if is_exclude: + value = value.removeprefix("^") + is_exact_match = value.startswith("=") + if is_exact_match: + value = value.removeprefix("=") + + device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] + + result = _resolve_device_specs(device_specs, is_exact_match, available_devices) + if is_exclude: + result = [dev for dev in available_devices if dev not in result] + if len(result) > max_hcas: + result = result[:max_hcas] + + logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}") + + return result + + +def _resolve_device_specs( + device_specs: list[str], is_exact_match: bool, available_devices: list[str] +) -> list[str]: + devices = set() + for spec in device_specs: + parts = spec.split(":", 1) + device_name = parts[0].strip() + # HACK: mooncake transfer engine does not support port specification yet, so we ignore it + # port = parts[1].strip() if len(parts) > 1 else None + base_devices = ( + [device_name] + if device_name in available_devices + else [] + if is_exact_match + else [dev for dev in available_devices if dev.startswith(device_name)] + ) + + if not base_devices: + logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.") + continue + + for base_dev in base_devices: + devices.add(base_dev) + + return sorted(devices) + + +class P2PStore: + def __init__(self, device_manager: DeviceManager): + from mooncake.engine import TransferEngine + + self.rank = int(os.getenv("RANK")) + gpu_count = device_manager.device_module.device_count() + local_rank = self.rank % gpu_count + device_type = device_manager.device_type + if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None: + self.device = "" + else: + self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) + self.ip = get_ip() + + # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases + retry_count = 8 + for i in range(retry_count): + self.engine = TransferEngine() + ret = self.engine.initialize( + self.ip, + "P2PHANDSHAKE", + "ascend_direct" if device_type == "npu" else "rdma", + self.device, + ) + if ret == 0: + break + # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time + sleep_ms = random.randint(500, 2000) + logger.warning( + f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms" + ) + time.sleep(sleep_ms / 1000) + else: + raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine") + self.port = self.engine.get_rpc_port() + self.named_tensors: dict[str, torch.Tensor] = {} + logger.info( + f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}" + ) + + @property + def addr(self) -> str: + return f"{self.ip}:{self.port}" + + def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]): + buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()] + capacities = [tensor.nbytes for tensor in named_tensors.values()] + self.named_tensors.update(named_tensors) + for i, name in enumerate(named_tensors.keys()): + logger.info( + f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}" + ) + assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0 + + def unregister_named_tensors(self, names: list[str]) -> int: + buffer_addresses = [self.named_tensors[name].data_ptr() for name in names] + assert self.engine.batch_unregister_memory(buffer_addresses) == 0 + num_unregistered = 0 + for i, name in enumerate(names): + del self.named_tensors[name] + logger.info( + f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}" + ) + num_unregistered += 1 + return num_unregistered + + def batch_transfer_sync_read( + self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int] + ): + assert ( + self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0 + ) diff --git a/checkpoint_engine/pin_memory.py b/checkpoint_engine/pin_memory.py new file mode 100644 index 0000000..3edcb12 --- /dev/null +++ b/checkpoint_engine/pin_memory.py @@ -0,0 +1,390 @@ +import concurrent.futures +import json +import os +import pickle +from typing import TYPE_CHECKING, Any, BinaryIO + +import numpy as np +import torch +from loguru import logger +from pydantic import BaseModel +from safetensors.torch import _getdtype, safe_open + +from checkpoint_engine.data_types import ( + MemoryBuffer, + ParameterMeta, +) + + +if TYPE_CHECKING: + from checkpoint_engine.data_types import FileMeta + +# 256 bytes alignment when flatten torch tensors to uint8 buffer +_ALIGN_SIZE = 256 + + +def _align_size(dtype: torch.dtype, shape: torch.Size) -> int: + return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE + + +def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]: + def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]: + ret = {} + with safe_open(fn, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + weight = f.get_tensor(name) + meta = { + "key": name, + "dtype": weight.dtype, + "shape": weight.shape, + "type": type(weight), + "tp_concat_dim": -1, # safetensors does not support tp_concat_dim + } + ret[name] = (meta, weight) + return ret + + # deprecated, will be removed in the future + def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]: + """load *.np file and return memmap and related tensor meta""" + + def parse_npy_header(fin: BinaryIO) -> dict[str, Any]: + start = fin.tell() + major, minor = np.lib.format.read_magic(fin) + if major == 1 and minor == 0: + read_header_fn = np.lib.format.read_array_header_1_0 + elif major == 2 and minor == 0: + read_header_fn = np.lib.format.read_array_header_2_0 + else: + raise ValueError( + f"unknown version {major}.{minor} when parsing npy header from {fn}" + ) + shape, is_fortran, dtype = read_header_fn(fin) + return { + "shape": shape, + "is_fortran": is_fortran, + "dtype": dtype, + "header_length": fin.tell() - start, + } + + meta_fn = fn + ".meta" + with open(meta_fn, "rb") as fin: + meta_lst = pickle.load(fin) + + tensors = [] + offset = 0 + with open(fn, "rb") as fin: + fin.seek(0, os.SEEK_END) + filesize = fin.tell() + fin.seek(0) + while fin.tell() < filesize: + tensor_meta = parse_npy_header(fin) + tensor = np.memmap( + fn, + dtype=tensor_meta["dtype"], + mode="c", + offset=offset + tensor_meta["header_length"], + shape=tensor_meta["shape"], + ) + offset += tensor_meta["header_length"] + tensor.nbytes + fin.seek(offset) + tensors.append(tensor) + + assert len(meta_lst) == len(tensors) + ret = {} + for meta, tensor in zip(meta_lst, tensors): + if meta["type"] == torch.Tensor: + tensor = torch.from_numpy(tensor) + tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"]) + ret[meta["key"]] = (meta, tensor) + return ret + + tp_rank = 0 + if file_path.endswith(".npy"): + logger.warning("numpy model file is deprecated, will be removed in the future") + filename_split = os.path.basename(file_path).split(".") + # if using numpy and want to specify tp rank + # file should be in model.{layer}.{tp}[.{ep}].npy format + tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0 + ret = _fast_np_load(file_path) + elif file_path.endswith(".safetensors"): + ret = _safetensors_load(file_path) + else: + raise ValueError(f"unsupported file format: {file_path}") + return tp_rank, ret + + +def _concat_tp_weights( + tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int +) -> torch.Tensor: + """Concat tp weights with meta info. + If meta.concat_dim is -1, means this is shared tp weights, just use the first weights. + Else we will cat weights in concat_dim. + """ + if tp_concat_dim == -1: + return tp_weights[0] + assert tp_size == len(tp_weights) + if len(tp_weights) == 1: + return tp_weights[0] + return torch.cat([w for w in tp_weights], dim=tp_concat_dim) + + +def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]: + class TPMeta(BaseModel): + concat_dim: int + size: int + + parameters: dict[str, torch.Tensor] = {} + parameter_metas: dict[str, ParameterMeta] = {} + tp_metas: dict[str, TPMeta] = {} + parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {} + for file in files: + tp_rank, ret = _load_checkpoint_file(file) + for parameter_name, (meta, weight) in ret.items(): + if parameter_name not in parameters_with_tp: + parameters_with_tp[parameter_name] = {} + parameters_with_tp[parameter_name][tp_rank] = weight + if parameter_name not in tp_metas: + tp_metas[parameter_name] = TPMeta( + concat_dim=meta["tp_concat_dim"], + size=1, + ) + if parameter_name not in parameter_metas: + assert isinstance(meta["dtype"], torch.dtype), ( + f"meta {meta} dtype should be torch.dtype" + ) + assert isinstance(meta["shape"], torch.Size), ( + f"meta {meta} shape should be torch.Size" + ) + parameter_metas[parameter_name] = ParameterMeta( + name=parameter_name, + shape=meta["shape"], + dtype=meta["dtype"], + aligned_size=_align_size(meta["dtype"], meta["shape"]), + ) + tp_meta = tp_metas[parameter_name] + if tp_meta.concat_dim != -1: + tp_meta.size = max(tp_meta.size, tp_rank + 1) + for name, tp_meta in tp_metas.items(): + if tp_meta.concat_dim != -1: + shape = list(parameter_metas[name].shape) + shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size + parameter_metas[name] = ParameterMeta( + name=name, + shape=torch.Size(shape), + dtype=parameter_metas[name].dtype, + aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)), + ) + weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])] + # TODO: here concat is serial, which may be slow + # but since tp storage is not used in the future + # we ignore this performance issue for now + parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size) + for name, parameter in parameters.items(): + assert name in parameter_metas, f"parameter {name} not found in parameter_metas" + assert parameter_metas[name].shape == parameter.shape, ( + f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}" + ) + assert parameter_metas[name].dtype == parameter.dtype, ( + f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}" + ) + return parameters + + +def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]: + def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer: + """ + safetensors format see https://huggingface.co/docs/safetensors/en/index#format. + We load the safetensors file as bytes, then parse the header manually to get parameter metas. + The actual tensor data is in the remaining bytes and is naturally aligned. + We pin the remaining bytes as the buffer, making pinning faster. + """ + + def _pin(t: torch.Tensor): + """ + Pin the memory of tensor in-place. + See: https://github.com/pytorch/pytorch/issues/32167 + """ + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) + assert r == 0, f"pin memory error, error code: {r}" + + # TODO: should only support /dev/shm? but we found files in disk also work? + size = os.stat(file_path).st_size + flag_size = 8 + t = torch.from_file(file_path, True, size, dtype=torch.uint8) + assert t.nbytes > flag_size, ( + f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}" + ) + start_pos = ( + int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False) + + flag_size + ) + header_tensor = t[flag_size:start_pos] + header = json.loads(header_tensor.numpy().tobytes()) + if "__metadata__" in header: + header.pop("__metadata__") + + metas: list[ParameterMeta] = [] + offset = 0 + try: + for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): + start, end = meta["data_offsets"] + # safetensors format ensures offsets are aligned + assert offset == start, f"offset {offset} should be equal to start {start}" + metas.append( + ParameterMeta( + name=name, + dtype=_getdtype(meta["dtype"]), + shape=torch.Size(meta["shape"]), + aligned_size=end - start, + ) + ) + offset = end + except Exception as e: + logger.error(f"fail to parse safetensors header from {file_path}: {e}") + raise + + buffer = t[start_pos:] + assert offset == buffer.nbytes, ( + f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}" + ) + # Remove the file after successfully loading. This will avoid doubling the memory usage. + # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading. + os.remove(file_path) + _pin(buffer) + logger.info( + f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB" + ) + return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True) + + memory_buffers: list[MemoryBuffer] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files)) + return memory_buffers + + +def _normal_pin_memory( + files: list[str], + named_tensors: dict[str, torch.Tensor], + rank: int | None = None, + shared_pin_memory: list[MemoryBuffer] | None = None, +) -> list[MemoryBuffer]: + parameters = _load_checkpoint(files) + if named_tensors: + parameters.update(named_tensors) + bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) + + class MemoryBucket(BaseModel): + size: int + metas: list[ParameterMeta] + + buckets: list[MemoryBucket] = [] + buckets.append(MemoryBucket(size=0, metas=[])) + for name, tensor in sorted(parameters.items()): + size = _align_size(tensor.dtype, tensor.shape) + if buckets[-1].size + size > bucket_size: + assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" + buckets.append(MemoryBucket(size=0, metas=[])) + buckets[-1].metas.append( + ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size) + ) + buckets[-1].size += size + + memory_buffers = [ + MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) + for bucket in buckets + ] + + def register_pin_memory( + idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None + ) -> tuple[int, torch.Tensor]: + if shared_pin_memory: + # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one + # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time + assert idx < len(shared_pin_memory), ( + f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" + ) + assert shared_pin_memory[idx].size == size, ( + f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" + ) + return idx, shared_pin_memory[idx].buffer + else: + buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) + return idx, buffer + + def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): + buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + futures = [ + executor.submit( + register_pin_memory, + idx, + bucket.size, + shared_pin_memory, + ) + for idx, bucket in enumerate(buckets) + ] + new_futures = [] + for future in concurrent.futures.as_completed(futures): + idx, buffer = future.result() + assert buffer.numel() == buckets[idx].size, ( + f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}" + ) + memory_buffers[idx].buffer = buffer + logger.info( + f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, " + f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer" + ) + offset = 0 + for meta in buckets[idx].metas: + name = meta.name + tensor = parameters[name] + size = _align_size(tensor.dtype, tensor.shape) + assert size == _align_size(meta.dtype, meta.shape), ( + f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}" + ) + new_futures.append(executor.submit(register_tensor, buffer, offset, tensor)) + offset += size + for future in concurrent.futures.as_completed(new_futures): + future.result() + return memory_buffers + + +def _register_checkpoint( + *, + files: list[str], + named_tensors: dict[str, torch.Tensor], + rank: int | None = None, + shared_pin_memory: list[MemoryBuffer] | None = None, + inplace_pin: bool = False, +) -> list[MemoryBuffer]: + logger.info( + f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" + ) + if not files and not named_tensors: + return [] + memory_buffers: list[MemoryBuffer] = [] + if inplace_pin: + logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files") + files_to_inplace_pin = [ + file + for file in files + if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108 + ] + files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin] + else: + files_to_normal_pin = files + files_to_inplace_pin = [] + if files_to_normal_pin or named_tensors: + memory_buffers.extend( + _normal_pin_memory( + files=files_to_normal_pin, + named_tensors=named_tensors, + rank=rank, + shared_pin_memory=shared_pin_memory, + ) + ) + if files_to_inplace_pin: + memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank)) + return memory_buffers diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 467cbd3..778d588 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -1,143 +1,33 @@ -import argparse -import concurrent.futures import ctypes -import json import os -import pickle -import random import threading -import time from collections import defaultdict from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple +from typing import TYPE_CHECKING -import httpx -import numpy as np import torch import torch.distributed as dist import zmq from loguru import logger -from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema -from safetensors.torch import _getdtype, safe_open from torch.multiprocessing.reductions import reduce_tensor +from checkpoint_engine.data_types import ( + BucketRange, + DataToGather, + H2DBucket, + MemoryBuffer, + MemoryBufferMetaList, + MemoryBufferMetas, + ParameterMeta, +) from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid +from checkpoint_engine.p2p_store import P2PStore +from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint if TYPE_CHECKING: - from typing import TypeVar - - from typing_extensions import TypedDict - - class FileMeta(TypedDict): - key: str # parameter name - dtype: torch.dtype - shape: torch.Size - type: type - tp_concat_dim: int - - T = TypeVar("T") - - -def _dt_validate(value: Any) -> torch.dtype: - if isinstance(value, str): - if not value.startswith("torch."): - raise ValueError(f"dtype {value} should start with torch.") - try: - value = getattr(torch, value.split(".")[1]) - except AttributeError as e: - raise ValueError(f"unknown dtype: {value}") from e - if not isinstance(value, torch.dtype): - raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}") - return value - - -_TorchDtype = Annotated[ - torch.dtype, - PlainValidator(_dt_validate), - PlainSerializer(lambda x: str(x), return_type=str), - WithJsonSchema({"type": "string"}, mode="serialization"), -] - - -def _size_validate(value: Any) -> torch.Size: - if isinstance(value, list | tuple): - return torch.Size(value) - if not isinstance(value, torch.Size): - raise TypeError(f"size {value} should be torch.Size, got {type(value)}") - return value - - -_TorchSize = Annotated[ - torch.Size, - PlainValidator(_size_validate), - PlainSerializer(lambda x: tuple(x), return_type=tuple), - WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"), -] - - -def _tensor_validate(value: Any) -> torch.Tensor: - if isinstance(value, torch.Tensor): - return value - raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}") - - -_TorchTensor = Annotated[ - torch.Tensor, - PlainValidator(_tensor_validate), -] - - -class ParameterMeta(BaseModel): - name: str - dtype: _TorchDtype - shape: _TorchSize - aligned_size: int - - -class BucketRange(NamedTuple): - idx: int # bucket_idx of MemoryBucket in memory_pool - offset: int - size: int - - -class H2DBucket(BaseModel): - size: int - ranges: list[BucketRange] - items: list[ParameterMeta] - - -class MemoryBufferMetas(BaseModel): - metas: list[ParameterMeta] - ptr: int - size: int - - -class MemoryBuffer(BaseModel): - buffer: _TorchTensor - size: int - metas: list[ParameterMeta] - manually_pinned: bool = False - - -class MemoryBufferMetaList(BaseModel): - p2p_store_addr: str | None - memory_buffer_metas_list: list[MemoryBufferMetas] - rdma_device: str - - -class DataToGather(MemoryBufferMetaList): - host_ip: str - device_uuid: str - - -# 256 bytes alignment when flatten torch tensors to uint8 buffer -_ALIGN_SIZE = 256 - - -def _align_size(dtype: torch.dtype, shape: torch.Size) -> int: - return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE + from checkpoint_engine.data_types import T def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]: @@ -156,107 +46,6 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]: return ret -def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]: - def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]: - ret = {} - with safe_open(fn, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - weight = f.get_tensor(name) - meta = { - "key": name, - "dtype": weight.dtype, - "shape": weight.shape, - "type": type(weight), - "tp_concat_dim": -1, # safetensors does not support tp_concat_dim - } - ret[name] = (meta, weight) - return ret - - # deprecated, will be removed in the future - def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]: - """load *.np file and return memmap and related tensor meta""" - - def parse_npy_header(fin: BinaryIO) -> dict[str, Any]: - start = fin.tell() - major, minor = np.lib.format.read_magic(fin) - if major == 1 and minor == 0: - read_header_fn = np.lib.format.read_array_header_1_0 - elif major == 2 and minor == 0: - read_header_fn = np.lib.format.read_array_header_2_0 - else: - raise ValueError( - f"unknown version {major}.{minor} when parsing npy header from {fn}" - ) - shape, is_fortran, dtype = read_header_fn(fin) - return { - "shape": shape, - "is_fortran": is_fortran, - "dtype": dtype, - "header_length": fin.tell() - start, - } - - meta_fn = fn + ".meta" - with open(meta_fn, "rb") as fin: - meta_lst = pickle.load(fin) - - tensors = [] - offset = 0 - with open(fn, "rb") as fin: - fin.seek(0, os.SEEK_END) - filesize = fin.tell() - fin.seek(0) - while fin.tell() < filesize: - tensor_meta = parse_npy_header(fin) - tensor = np.memmap( - fn, - dtype=tensor_meta["dtype"], - mode="c", - offset=offset + tensor_meta["header_length"], - shape=tensor_meta["shape"], - ) - offset += tensor_meta["header_length"] + tensor.nbytes - fin.seek(offset) - tensors.append(tensor) - - assert len(meta_lst) == len(tensors) - ret = {} - for meta, tensor in zip(meta_lst, tensors): - if meta["type"] == torch.Tensor: - tensor = torch.from_numpy(tensor) - tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"]) - ret[meta["key"]] = (meta, tensor) - return ret - - tp_rank = 0 - if file_path.endswith(".npy"): - logger.warning("numpy model file is deprecated, will be removed in the future") - filename_split = os.path.basename(file_path).split(".") - # if using numpy and want to specify tp rank - # file should be in model.{layer}.{tp}[.{ep}].npy format - tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0 - ret = _fast_np_load(file_path) - elif file_path.endswith(".safetensors"): - ret = _safetensors_load(file_path) - else: - raise ValueError(f"unsupported file format: {file_path}") - return tp_rank, ret - - -def _concat_tp_weights( - tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int -) -> torch.Tensor: - """Concat tp weights with meta info. - If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights. - Else we will cat weights in concat_dim. - """ - if tp_concat_dim == -1: - return tp_weights[0] - assert tp_size == len(tp_weights) - if len(tp_weights) == 1: - return tp_weights[0] - return torch.cat([w for w in tp_weights], dim=tp_concat_dim) - - def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: try: if device_manager.device_type == "npu": @@ -267,426 +56,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None raise ValueError(f"fail to get physical gpu id {device_index}") from e -def _ibv_get_device_list() -> list[str]: - lib = ctypes.CDLL("libibverbs.so.1") - lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices - lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device ** - - lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)] - lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device * - lib.ibv_get_device_name.restype = ctypes.c_char_p # const char * - - num = ctypes.c_int() - dev_array = lib.ibv_get_device_list(ctypes.byref(num)) - if not dev_array or num.value <= 0: - return [] - - devices = [] - for i in range(num.value): - dev_ptr = dev_array[i] # struct ibv_device * - name = lib.ibv_get_device_name(dev_ptr) # const char * - devices.append(name.decode()) - lib.ibv_free_device_list(dev_array) - return devices - - -def _get_rdma_devices() -> list[str]: - """ - use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return - """ - devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES") - if devices_str: - return devices_str.split(",") - # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices - hca = os.getenv("NCCL_IB_HCA", None) - return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list() - - -def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: - """ - implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc. - """ - if not devices: - raise RuntimeError("no rdma devices found") - try: - assert len(devices) <= gpu_count, ( - f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" - ) - assert gpu_count % len(devices) == 0, ( - f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" - ) - return devices[local_rank // (gpu_count // len(devices))] - except AssertionError: - logger.error( - "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices." - "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices." - "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'." - ) - raise - - -def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: - """ - The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8. - The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662. - - The list is comma-separated; port numbers are NOT supported yet. - An optional prefix '^' indicates the list is an exclude list. - A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix. - Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported. - - Examples: - - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`. - - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`. - - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`. - - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`. - """ - max_hcas = 32 - if not value or value.strip() == "": - return available_devices[:max_hcas] - - value = value.strip() - result = [] - is_exclude = value.startswith("^") - if is_exclude: - value = value.removeprefix("^") - is_exact_match = value.startswith("=") - if is_exact_match: - value = value.removeprefix("=") - - device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] - - result = _resolve_device_specs(device_specs, is_exact_match, available_devices) - if is_exclude: - result = [dev for dev in available_devices if dev not in result] - if len(result) > max_hcas: - result = result[:max_hcas] - - logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}") - - return result - - -def _resolve_device_specs( - device_specs: list[str], is_exact_match: bool, available_devices: list[str] -) -> list[str]: - devices = set() - for spec in device_specs: - parts = spec.split(":", 1) - device_name = parts[0].strip() - # HACK: mooncake transfer engine does not support port specification yet, so we ignore it - # port = parts[1].strip() if len(parts) > 1 else None - base_devices = ( - [device_name] - if device_name in available_devices - else [] - if is_exact_match - else [dev for dev in available_devices if dev.startswith(device_name)] - ) - - if not base_devices: - logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.") - continue - - for base_dev in base_devices: - devices.add(base_dev) - - return sorted(devices) - - -def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]: - class TPMeta(BaseModel): - concat_dim: int - size: int - - parameters: dict[str, torch.Tensor] = {} - parameter_metas: dict[str, ParameterMeta] = {} - tp_metas: dict[str, TPMeta] = {} - parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {} - for file in files: - tp_rank, ret = _load_checkpoint_file(file) - for parameter_name, (meta, weight) in ret.items(): - if parameter_name not in parameters_with_tp: - parameters_with_tp[parameter_name] = {} - parameters_with_tp[parameter_name][tp_rank] = weight - if parameter_name not in tp_metas: - tp_metas[parameter_name] = TPMeta( - concat_dim=meta["tp_concat_dim"], - size=1, - ) - if parameter_name not in parameter_metas: - assert isinstance(meta["dtype"], torch.dtype), ( - f"meta {meta} dtype should be torch.dtype" - ) - assert isinstance(meta["shape"], torch.Size), ( - f"meta {meta} shape should be torch.Size" - ) - parameter_metas[parameter_name] = ParameterMeta( - name=parameter_name, - shape=meta["shape"], - dtype=meta["dtype"], - aligned_size=_align_size(meta["dtype"], meta["shape"]), - ) - tp_meta = tp_metas[parameter_name] - if tp_meta.concat_dim != -1: - tp_meta.size = max(tp_meta.size, tp_rank + 1) - for name, tp_meta in tp_metas.items(): - if tp_meta.concat_dim != -1: - shape = list(parameter_metas[name].shape) - shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size - parameter_metas[name] = ParameterMeta( - name=name, - shape=torch.Size(shape), - dtype=parameter_metas[name].dtype, - aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)), - ) - weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])] - # TODO: here concat is serial, which may be slow - # but since tp storage is not used in the future - # we ignore this performance issue for now - parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size) - for name, parameter in parameters.items(): - assert name in parameter_metas, f"parameter {name} not found in parameter_metas" - assert parameter_metas[name].shape == parameter.shape, ( - f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}" - ) - assert parameter_metas[name].dtype == parameter.dtype, ( - f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}" - ) - return parameters - - -def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]: - def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer: - """ - safetensors format see https://huggingface.co/docs/safetensors/en/index#format. - We load the safetensors file as bytes, then parse the header manually to get parameter metas. - The actual tensor data is in the remaining bytes and is naturally aligned. - We pin the remaining bytes as the buffer, making pinning faster. - """ - - def _pin(t: torch.Tensor): - """ - Pin the memory of tensor in-place. - See: https://github.com/pytorch/pytorch/issues/32167 - """ - cudart = torch.cuda.cudart() - r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) - assert r == 0, f"pin memory error, error code: {r}" - - # TODO: should only support /dev/shm? but we found files in disk also work? - size = os.stat(file_path).st_size - flag_size = 8 - t = torch.from_file(file_path, True, size, dtype=torch.uint8) - assert t.nbytes > flag_size, ( - f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}" - ) - start_pos = ( - int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False) - + flag_size - ) - header_tensor = t[flag_size:start_pos] - header = json.loads(header_tensor.numpy().tobytes()) - if "__metadata__" in header: - header.pop("__metadata__") - - metas: list[ParameterMeta] = [] - offset = 0 - try: - for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): - start, end = meta["data_offsets"] - # safetensors format ensures offsets are aligned - assert offset == start, f"offset {offset} should be equal to start {start}" - metas.append( - ParameterMeta( - name=name, - dtype=_getdtype(meta["dtype"]), - shape=torch.Size(meta["shape"]), - aligned_size=end - start, - ) - ) - offset = end - except Exception as e: - logger.error(f"fail to parse safetensors header from {file_path}: {e}") - raise - - buffer = t[start_pos:] - assert offset == buffer.nbytes, ( - f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}" - ) - # Remove the file after successfully loading. This will avoid doubling the memory usage. - # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading. - os.remove(file_path) - _pin(buffer) - logger.info( - f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB" - ) - return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True) - - memory_buffers: list[MemoryBuffer] = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files)) - return memory_buffers - - -def _normal_pin_memory( - files: list[str], - named_tensors: dict[str, torch.Tensor], - rank: int | None = None, - shared_pin_memory: list[MemoryBuffer] | None = None, -) -> list[MemoryBuffer]: - parameters = _load_checkpoint(files) - if named_tensors: - parameters.update(named_tensors) - bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) - - class MemoryBucket(BaseModel): - size: int - metas: list[ParameterMeta] - - buckets: list[MemoryBucket] = [] - buckets.append(MemoryBucket(size=0, metas=[])) - for name, tensor in sorted(parameters.items()): - size = _align_size(tensor.dtype, tensor.shape) - if buckets[-1].size + size > bucket_size: - assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" - buckets.append(MemoryBucket(size=0, metas=[])) - buckets[-1].metas.append( - ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size) - ) - buckets[-1].size += size - - memory_buffers = [ - MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) - for bucket in buckets - ] - - def register_pin_memory( - idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None - ) -> tuple[int, torch.Tensor]: - if shared_pin_memory: - # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one - # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time - assert idx < len(shared_pin_memory), ( - f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" - ) - assert shared_pin_memory[idx].size == size, ( - f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" - ) - return idx, shared_pin_memory[idx].buffer - else: - buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) - return idx, buffer - - def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): - buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) - - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - futures = [ - executor.submit( - register_pin_memory, - idx, - bucket.size, - shared_pin_memory, - ) - for idx, bucket in enumerate(buckets) - ] - new_futures = [] - for future in concurrent.futures.as_completed(futures): - idx, buffer = future.result() - assert buffer.numel() == buckets[idx].size, ( - f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}" - ) - memory_buffers[idx].buffer = buffer - logger.info( - f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, " - f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer" - ) - offset = 0 - for meta in buckets[idx].metas: - name = meta.name - tensor = parameters[name] - size = _align_size(tensor.dtype, tensor.shape) - assert size == _align_size(meta.dtype, meta.shape), ( - f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}" - ) - new_futures.append(executor.submit(register_tensor, buffer, offset, tensor)) - offset += size - for future in concurrent.futures.as_completed(new_futures): - future.result() - return memory_buffers - - -def _register_checkpoint( - *, - files: list[str], - named_tensors: dict[str, torch.Tensor], - rank: int | None = None, - shared_pin_memory: list[MemoryBuffer] | None = None, - inplace_pin: bool = False, -) -> list[MemoryBuffer]: - logger.info( - f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" - ) - if not files and not named_tensors: - return [] - memory_buffers: list[MemoryBuffer] = [] - if inplace_pin: - logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files") - files_to_inplace_pin = [ - file - for file in files - if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108 - ] - files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin] - else: - files_to_normal_pin = files - files_to_inplace_pin = [] - if files_to_normal_pin or named_tensors: - memory_buffers.extend( - _normal_pin_memory( - files=files_to_normal_pin, - named_tensors=named_tensors, - rank=rank, - shared_pin_memory=shared_pin_memory, - ) - ) - if files_to_inplace_pin: - memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank)) - return memory_buffers - - -def request_inference_to_update( - url: str, - socket_paths: dict[str, str], - timeout: float = 300.0, - uds: str | None = None, -): - """Send an inference update request to inference server via HTTP or Unix socket. - - Args: - url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to. - socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights. - timeout (float, optional): Request timeout in seconds. Defaults to 300.0. - uds (str, optional): Path to a Unix domain socket. If provided, the request - will be sent via the Unix socket instead of HTTP. Defaults to None. - - Raises: - httpx.HTTPStatusError: If the response contains an HTTP error status. - httpx.RequestError: If there was an issue while making the request. - """ - resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post( - url, - json={ - "method": "update_weights_from_ipc", - "args": [socket_paths], - "timeout": timeout, - }, - timeout=timeout, - ) - resp.raise_for_status() - - def _gen_h2d_buckets( global_metas: dict[int, MemoryBufferMetaList], bucket_size: int, @@ -793,80 +162,6 @@ def _get_master_port(master_port: int | None = None) -> int: return master_port -class P2PStore: - def __init__(self, device_manager: DeviceManager): - from mooncake.engine import TransferEngine - - self.rank = int(os.getenv("RANK")) - gpu_count = device_manager.device_module.device_count() - local_rank = self.rank % gpu_count - device_type = device_manager.device_type - if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None: - self.device = "" - else: - self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) - self.ip = get_ip() - - # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases - retry_count = 8 - for i in range(retry_count): - self.engine = TransferEngine() - ret = self.engine.initialize( - self.ip, - "P2PHANDSHAKE", - "ascend_direct" if device_type == "npu" else "rdma", - self.device, - ) - if ret == 0: - break - # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time - sleep_ms = random.randint(500, 2000) - logger.warning( - f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms" - ) - time.sleep(sleep_ms / 1000) - else: - raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine") - self.port = self.engine.get_rpc_port() - self.named_tensors: dict[str, torch.Tensor] = {} - logger.info( - f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}" - ) - - @property - def addr(self) -> str: - return f"{self.ip}:{self.port}" - - def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]): - buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()] - capacities = [tensor.nbytes for tensor in named_tensors.values()] - self.named_tensors.update(named_tensors) - for i, name in enumerate(named_tensors.keys()): - logger.info( - f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}" - ) - assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0 - - def unregister_named_tensors(self, names: list[str]) -> int: - buffer_addresses = [self.named_tensors[name].data_ptr() for name in names] - assert self.engine.batch_unregister_memory(buffer_addresses) == 0 - num_unregistered = 0 - for i, name in enumerate(names): - del self.named_tensors[name] - logger.info( - f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}" - ) - num_unregistered += 1 - return num_unregistered - - def batch_transfer_sync_read( - self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int] - ): - assert ( - self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0 - ) - - class ParameterServer: shared_memory_pool_name = "__shared_memory_pool__" @@ -1559,79 +854,8 @@ def _update_per_bucket( self.device_manager.device_module.empty_cache() -def _init_api(ps: ParameterServer) -> Any: - import fastapi - from fastapi import Request - from fastapi.responses import JSONResponse, Response - - app = fastapi.FastAPI() - - class RegisterRequest(BaseModel): - files: list[str] - - class UpdateRequest(BaseModel): - ranks: list[int] = [] - update_url: str | None = None - inference_group_ranks: list[int] = [] - timeout: float = 300.0 - uds: str | None = None - - def wrap_exception(func: Callable[[], None]) -> Response: - try: - func() - except Exception as e: # noqa: BLE001 - logger.exception(f"wrap exception {func} failed") - return JSONResponse(content=str(e), status_code=500) - return Response(status_code=200) - - @app.post("/v1/checkpoints/{checkpoint_name}/files") - async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response: - return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files)) - - @app.delete("/v1/checkpoints/{checkpoint_name}") - async def unregister_checkpoint(checkpoint_name: str) -> Response: - return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name)) - - @app.get("/v1/healthz") - async def healthz() -> Response: - return Response(status_code=200) - - @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas") - async def gather_metas(checkpoint_name: str) -> Response: - return wrap_exception(lambda: ps.gather_metas(checkpoint_name)) - - @app.post("/v1/checkpoints/{checkpoint_name}/update") - async def update(checkpoint_name: str, req: UpdateRequest) -> Response: - def update_func(socket_paths: list[tuple[str, str]]): - if req.update_url is None: - return - if req.inference_group_ranks: - socket_paths = [socket_paths[i] for i in req.inference_group_ranks] - request_inference_to_update( - req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds - ) - - return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks)) - - return app - - -@logger.catch(reraise=True) -def run_from_cli(): - import uvicorn - - parser = argparse.ArgumentParser(description="Parameter Server") - parser.add_argument("--uds", type=str) - - args = parser.parse_args() - logger.info( - f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}" - ) - - assert args.uds and len(args.uds) > 0, args.uds - ps = ParameterServer(auto_pg=True) - uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60) - - +# we need this CLI entry point for compatibility with former versions if __name__ == "__main__": + from .__main__ import run_from_cli + run_from_cli() diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index 9b0951a..0b4130d 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -3,7 +3,7 @@ import pytest -from checkpoint_engine.ps import ( +from checkpoint_engine.p2p_store import ( _get_my_rdma_device, _get_rdma_devices, _ibv_get_device_list, @@ -42,7 +42,9 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): """Test _get_rdma_devices with no environment variables""" with ( patch.dict(os.environ, clear=True), - patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices), + patch( + "checkpoint_engine.p2p_store._ibv_get_device_list", return_value=mock_available_devices + ), ): devices = _get_rdma_devices() assert sorted(devices) == sorted(mock_available_devices) @@ -121,7 +123,7 @@ def test_parse_exact_match_with_nonexistent_device( mock_available_devices: list[str], ): """Test exact matching with non-existent device""" - with patch("checkpoint_engine.ps.logger") as mock_logger: + with patch("checkpoint_engine.p2p_store.logger") as mock_logger: result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) assert result == expected_result mock_logger.warning.assert_called_once_with(expected_warning) @@ -148,7 +150,9 @@ def test_get_rdma_devices_with_env_vars( env_dict = {env_var_name: env_var_value} with ( patch.dict(os.environ, env_dict), - patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices), + patch( + "checkpoint_engine.p2p_store._ibv_get_device_list", return_value=mock_available_devices + ), ): devices = _get_rdma_devices() assert sorted(devices) == sorted(expected_devices)