Skip to content

Commit e48c442

Browse files
committed
feat: split api.py from ps.py
1 parent 0696dd6 commit e48c442

File tree

2 files changed

+103
-91
lines changed

2 files changed

+103
-91
lines changed

checkpoint_engine/api.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
import httpx
6+
from loguru import logger
7+
from pydantic import BaseModel
8+
9+
10+
if TYPE_CHECKING:
11+
from collections.abc import Callable
12+
13+
from checkpoint_engine.ps import ParameterServer
14+
15+
16+
def request_inference_to_update(
17+
url: str,
18+
socket_paths: dict[str, str],
19+
timeout: float = 300.0,
20+
uds: str | None = None,
21+
):
22+
"""Send an inference update request to inference server via HTTP or Unix socket.
23+
24+
Args:
25+
url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
26+
socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
27+
timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
28+
uds (str, optional): Path to a Unix domain socket. If provided, the request
29+
will be sent via the Unix socket instead of HTTP. Defaults to None.
30+
31+
Raises:
32+
httpx.HTTPStatusError: If the response contains an HTTP error status.
33+
httpx.RequestError: If there was an issue while making the request.
34+
"""
35+
resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
36+
url,
37+
json={
38+
"method": "update_weights_from_ipc",
39+
"args": [socket_paths],
40+
"timeout": timeout,
41+
},
42+
timeout=timeout,
43+
)
44+
resp.raise_for_status()
45+
46+
47+
def _init_api(ps: ParameterServer) -> Any:
48+
import fastapi
49+
from fastapi import Request
50+
from fastapi.responses import JSONResponse, Response
51+
52+
app = fastapi.FastAPI()
53+
54+
class RegisterRequest(BaseModel):
55+
files: list[str]
56+
57+
class UpdateRequest(BaseModel):
58+
ranks: list[int] = []
59+
update_url: str | None = None
60+
inference_group_ranks: list[int] = []
61+
timeout: float = 300.0
62+
uds: str | None = None
63+
64+
def wrap_exception(func: Callable[[], None]) -> Response:
65+
try:
66+
func()
67+
except Exception as e: # noqa: BLE001
68+
logger.exception(f"wrap exception {func} failed")
69+
return JSONResponse(content=str(e), status_code=500)
70+
return Response(status_code=200)
71+
72+
@app.post("/v1/checkpoints/{checkpoint_name}/files")
73+
async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
74+
return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
75+
76+
@app.delete("/v1/checkpoints/{checkpoint_name}")
77+
async def unregister_checkpoint(checkpoint_name: str) -> Response:
78+
return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
79+
80+
@app.get("/v1/healthz")
81+
async def healthz() -> Response:
82+
return Response(status_code=200)
83+
84+
@app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
85+
async def gather_metas(checkpoint_name: str) -> Response:
86+
return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
87+
88+
@app.post("/v1/checkpoints/{checkpoint_name}/update")
89+
async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
90+
def update_func(socket_paths: list[tuple[str, str]]):
91+
if req.update_url is None:
92+
return
93+
if req.inference_group_ranks:
94+
socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
95+
request_inference_to_update(
96+
req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
97+
)
98+
99+
return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
100+
101+
return app

checkpoint_engine/ps.py

Lines changed: 2 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@
55
from collections import defaultdict
66
from collections.abc import Callable
77
from datetime import timedelta
8-
from typing import TYPE_CHECKING, Any
8+
from typing import TYPE_CHECKING
99

10-
import httpx
1110
import torch
1211
import torch.distributed as dist
1312
import zmq
1413
from loguru import logger
15-
from pydantic import BaseModel
1614
from torch.multiprocessing.reductions import reduce_tensor
1715

16+
from checkpoint_engine.api import _init_api
1817
from checkpoint_engine.data_types import (
1918
BucketRange,
2019
DataToGather,
@@ -59,37 +58,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
5958
raise ValueError(f"fail to get physical gpu id {device_index}") from e
6059

6160

62-
def request_inference_to_update(
63-
url: str,
64-
socket_paths: dict[str, str],
65-
timeout: float = 300.0,
66-
uds: str | None = None,
67-
):
68-
"""Send an inference update request to inference server via HTTP or Unix socket.
69-
70-
Args:
71-
url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
72-
socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
73-
timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
74-
uds (str, optional): Path to a Unix domain socket. If provided, the request
75-
will be sent via the Unix socket instead of HTTP. Defaults to None.
76-
77-
Raises:
78-
httpx.HTTPStatusError: If the response contains an HTTP error status.
79-
httpx.RequestError: If there was an issue while making the request.
80-
"""
81-
resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
82-
url,
83-
json={
84-
"method": "update_weights_from_ipc",
85-
"args": [socket_paths],
86-
"timeout": timeout,
87-
},
88-
timeout=timeout,
89-
)
90-
resp.raise_for_status()
91-
92-
9361
def _gen_h2d_buckets(
9462
global_metas: dict[int, MemoryBufferMetaList],
9563
bucket_size: int,
@@ -856,63 +824,6 @@ def _update_per_bucket(
856824
self.device_manager.device_module.empty_cache()
857825

858826

859-
def _init_api(ps: ParameterServer) -> Any:
860-
import fastapi
861-
from fastapi import Request
862-
from fastapi.responses import JSONResponse, Response
863-
864-
app = fastapi.FastAPI()
865-
866-
class RegisterRequest(BaseModel):
867-
files: list[str]
868-
869-
class UpdateRequest(BaseModel):
870-
ranks: list[int] = []
871-
update_url: str | None = None
872-
inference_group_ranks: list[int] = []
873-
timeout: float = 300.0
874-
uds: str | None = None
875-
876-
def wrap_exception(func: Callable[[], None]) -> Response:
877-
try:
878-
func()
879-
except Exception as e: # noqa: BLE001
880-
logger.exception(f"wrap exception {func} failed")
881-
return JSONResponse(content=str(e), status_code=500)
882-
return Response(status_code=200)
883-
884-
@app.post("/v1/checkpoints/{checkpoint_name}/files")
885-
async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
886-
return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
887-
888-
@app.delete("/v1/checkpoints/{checkpoint_name}")
889-
async def unregister_checkpoint(checkpoint_name: str) -> Response:
890-
return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
891-
892-
@app.get("/v1/healthz")
893-
async def healthz() -> Response:
894-
return Response(status_code=200)
895-
896-
@app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
897-
async def gather_metas(checkpoint_name: str) -> Response:
898-
return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
899-
900-
@app.post("/v1/checkpoints/{checkpoint_name}/update")
901-
async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
902-
def update_func(socket_paths: list[tuple[str, str]]):
903-
if req.update_url is None:
904-
return
905-
if req.inference_group_ranks:
906-
socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
907-
request_inference_to_update(
908-
req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
909-
)
910-
911-
return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
912-
913-
return app
914-
915-
916827
@logger.catch(reraise=True)
917828
def run_from_cli():
918829
import uvicorn

0 commit comments

Comments
 (0)