|
5 | 5 | from collections import defaultdict |
6 | 6 | from collections.abc import Callable |
7 | 7 | from datetime import timedelta |
8 | | -from typing import TYPE_CHECKING, Any |
| 8 | +from typing import TYPE_CHECKING |
9 | 9 |
|
10 | | -import httpx |
11 | 10 | import torch |
12 | 11 | import torch.distributed as dist |
13 | 12 | import zmq |
14 | 13 | from loguru import logger |
15 | | -from pydantic import BaseModel |
16 | 14 | from torch.multiprocessing.reductions import reduce_tensor |
17 | 15 |
|
| 16 | +from checkpoint_engine.api import _init_api |
18 | 17 | from checkpoint_engine.data_types import ( |
19 | 18 | BucketRange, |
20 | 19 | DataToGather, |
@@ -59,37 +58,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None |
59 | 58 | raise ValueError(f"fail to get physical gpu id {device_index}") from e |
60 | 59 |
|
61 | 60 |
|
62 | | -def request_inference_to_update( |
63 | | - url: str, |
64 | | - socket_paths: dict[str, str], |
65 | | - timeout: float = 300.0, |
66 | | - uds: str | None = None, |
67 | | -): |
68 | | - """Send an inference update request to inference server via HTTP or Unix socket. |
69 | | -
|
70 | | - Args: |
71 | | - url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to. |
72 | | - socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights. |
73 | | - timeout (float, optional): Request timeout in seconds. Defaults to 300.0. |
74 | | - uds (str, optional): Path to a Unix domain socket. If provided, the request |
75 | | - will be sent via the Unix socket instead of HTTP. Defaults to None. |
76 | | -
|
77 | | - Raises: |
78 | | - httpx.HTTPStatusError: If the response contains an HTTP error status. |
79 | | - httpx.RequestError: If there was an issue while making the request. |
80 | | - """ |
81 | | - resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post( |
82 | | - url, |
83 | | - json={ |
84 | | - "method": "update_weights_from_ipc", |
85 | | - "args": [socket_paths], |
86 | | - "timeout": timeout, |
87 | | - }, |
88 | | - timeout=timeout, |
89 | | - ) |
90 | | - resp.raise_for_status() |
91 | | - |
92 | | - |
93 | 61 | def _gen_h2d_buckets( |
94 | 62 | global_metas: dict[int, MemoryBufferMetaList], |
95 | 63 | bucket_size: int, |
@@ -856,63 +824,6 @@ def _update_per_bucket( |
856 | 824 | self.device_manager.device_module.empty_cache() |
857 | 825 |
|
858 | 826 |
|
859 | | -def _init_api(ps: ParameterServer) -> Any: |
860 | | - import fastapi |
861 | | - from fastapi import Request |
862 | | - from fastapi.responses import JSONResponse, Response |
863 | | - |
864 | | - app = fastapi.FastAPI() |
865 | | - |
866 | | - class RegisterRequest(BaseModel): |
867 | | - files: list[str] |
868 | | - |
869 | | - class UpdateRequest(BaseModel): |
870 | | - ranks: list[int] = [] |
871 | | - update_url: str | None = None |
872 | | - inference_group_ranks: list[int] = [] |
873 | | - timeout: float = 300.0 |
874 | | - uds: str | None = None |
875 | | - |
876 | | - def wrap_exception(func: Callable[[], None]) -> Response: |
877 | | - try: |
878 | | - func() |
879 | | - except Exception as e: # noqa: BLE001 |
880 | | - logger.exception(f"wrap exception {func} failed") |
881 | | - return JSONResponse(content=str(e), status_code=500) |
882 | | - return Response(status_code=200) |
883 | | - |
884 | | - @app.post("/v1/checkpoints/{checkpoint_name}/files") |
885 | | - async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response: |
886 | | - return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files)) |
887 | | - |
888 | | - @app.delete("/v1/checkpoints/{checkpoint_name}") |
889 | | - async def unregister_checkpoint(checkpoint_name: str) -> Response: |
890 | | - return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name)) |
891 | | - |
892 | | - @app.get("/v1/healthz") |
893 | | - async def healthz() -> Response: |
894 | | - return Response(status_code=200) |
895 | | - |
896 | | - @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas") |
897 | | - async def gather_metas(checkpoint_name: str) -> Response: |
898 | | - return wrap_exception(lambda: ps.gather_metas(checkpoint_name)) |
899 | | - |
900 | | - @app.post("/v1/checkpoints/{checkpoint_name}/update") |
901 | | - async def update(checkpoint_name: str, req: UpdateRequest) -> Response: |
902 | | - def update_func(socket_paths: list[tuple[str, str]]): |
903 | | - if req.update_url is None: |
904 | | - return |
905 | | - if req.inference_group_ranks: |
906 | | - socket_paths = [socket_paths[i] for i in req.inference_group_ranks] |
907 | | - request_inference_to_update( |
908 | | - req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds |
909 | | - ) |
910 | | - |
911 | | - return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks)) |
912 | | - |
913 | | - return app |
914 | | - |
915 | | - |
916 | 827 | @logger.catch(reraise=True) |
917 | 828 | def run_from_cli(): |
918 | 829 | import uvicorn |
|