|
19 | 19 | import torch |
20 | 20 | import torch.distributed as dist |
21 | 21 | import zmq |
| 22 | +from fastapi.encoders import jsonable_encoder |
22 | 23 | from loguru import logger |
23 | 24 | from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema |
24 | 25 | from safetensors.torch import safe_open |
@@ -1113,11 +1114,23 @@ class UpdateRequest(BaseModel): |
1113 | 1114 |
|
1114 | 1115 | def wrap_exception(func: Callable[[], None]) -> Response: |
1115 | 1116 | try: |
1116 | | - func() |
| 1117 | + ret = func() |
| 1118 | + return ( |
| 1119 | + Response(status_code=200) |
| 1120 | + if ret is None |
| 1121 | + else JSONResponse(jsonable_encoder(ret), status_code=200) |
| 1122 | + ) |
1117 | 1123 | except Exception as e: # noqa: BLE001 |
1118 | 1124 | logger.exception(f"wrap exception {func} failed") |
1119 | 1125 | return JSONResponse(content=str(e), status_code=500) |
1120 | | - return Response(status_code=200) |
| 1126 | + |
| 1127 | + @app.get("/v1/checkpoints/metas") |
| 1128 | + async def get_metas() -> Response: |
| 1129 | + return wrap_exception(lambda: ps.get_metas()) |
| 1130 | + |
| 1131 | + @app.post("/v1/checkpoints/metas") |
| 1132 | + async def load_metas(req: dict[int, MemoryBufferMetaList]) -> Response: |
| 1133 | + return wrap_exception(lambda: ps.load_metas(req)) |
1121 | 1134 |
|
1122 | 1135 | @app.post("/v1/checkpoints/{checkpoint_name}/files") |
1123 | 1136 | async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response: |
|
0 commit comments