Skip to content

Commit 3bebb19

Browse files
committed
feat: add get_metas and load_metas interface
1 parent 05827aa commit 3bebb19

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

checkpoint_engine/ps.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
import torch.distributed as dist
2121
import zmq
22+
from fastapi.encoders import jsonable_encoder
2223
from loguru import logger
2324
from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema
2425
from safetensors.torch import safe_open
@@ -1113,11 +1114,23 @@ class UpdateRequest(BaseModel):
11131114

11141115
def wrap_exception(func: Callable[[], None]) -> Response:
11151116
try:
1116-
func()
1117+
ret = func()
1118+
return (
1119+
Response(status_code=200)
1120+
if ret is None
1121+
else JSONResponse(jsonable_encoder(ret), status_code=200)
1122+
)
11171123
except Exception as e: # noqa: BLE001
11181124
logger.exception(f"wrap exception {func} failed")
11191125
return JSONResponse(content=str(e), status_code=500)
1120-
return Response(status_code=200)
1126+
1127+
@app.get("/v1/checkpoints/metas")
1128+
async def get_metas() -> Response:
1129+
return wrap_exception(lambda: ps.get_metas())
1130+
1131+
@app.post("/v1/checkpoints/metas")
1132+
async def load_metas(req: dict[int, MemoryBufferMetaList]) -> Response:
1133+
return wrap_exception(lambda: ps.load_metas(req))
11211134

11221135
@app.post("/v1/checkpoints/{checkpoint_name}/files")
11231136
async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:

0 commit comments

Comments
 (0)