Skip to content

Commit e1e0ec5

Browse files
committed
fix register_files fastapi parameter parse error
1 parent 716c0da commit e1e0ec5

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

checkpoint_engine/ps.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
import argparse
42
import concurrent.futures
53
import ctypes
@@ -10,6 +8,7 @@
108
import threading
119
import time
1210
from collections import defaultdict
11+
from collections.abc import Callable
1312
from datetime import timedelta
1413
from functools import lru_cache
1514
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
@@ -26,8 +25,6 @@
2625

2726

2827
if TYPE_CHECKING:
29-
from collections.abc import Callable
30-
3128
from typing_extensions import TypedDict
3229

3330
class FileMeta(TypedDict):
@@ -151,8 +148,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
151148
return ret
152149

153150

154-
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
155-
def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
151+
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple['FileMeta', torch.Tensor]]]:
152+
def _safetensors_load(fn: str) -> dict[str, tuple['FileMeta', torch.Tensor]]:
156153
ret = {}
157154
with safe_open(fn, framework="pt") as f:
158155
for name in f.keys(): # noqa: SIM118
@@ -168,7 +165,7 @@ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
168165
return ret
169166

170167
# deprecated, will be removed in the future
171-
def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
168+
def _fast_np_load(fn: str) -> dict[str, tuple['FileMeta', torch.Tensor]]:
172169
"""load *.np file and return memmap and related tensor meta"""
173170

174171
def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:

0 commit comments

Comments
 (0)