1- from __future__ import annotations
2-
31import argparse
42import concurrent .futures
53import ctypes
108import threading
119import time
1210from collections import defaultdict
11+ from collections .abc import Callable
1312from datetime import timedelta
1413from functools import lru_cache
1514from typing import TYPE_CHECKING , Annotated , Any , BinaryIO , NamedTuple
2625
2726
2827if TYPE_CHECKING :
29- from collections .abc import Callable
30-
3128 from typing_extensions import TypedDict
3229
3330 class FileMeta (TypedDict ):
@@ -151,8 +148,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
151148 return ret
152149
153150
154- def _load_checkpoint_file (file_path : str ) -> tuple [int , dict [str , tuple [FileMeta , torch .Tensor ]]]:
155- def _safetensors_load (fn : str ) -> dict [str , tuple [FileMeta , torch .Tensor ]]:
151+ def _load_checkpoint_file (file_path : str ) -> tuple [int , dict [str , tuple [' FileMeta' , torch .Tensor ]]]:
152+ def _safetensors_load (fn : str ) -> dict [str , tuple [' FileMeta' , torch .Tensor ]]:
156153 ret = {}
157154 with safe_open (fn , framework = "pt" ) as f :
158155 for name in f .keys (): # noqa: SIM118
@@ -168,7 +165,7 @@ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
168165 return ret
169166
170167 # deprecated, will be removed in the future
171- def _fast_np_load (fn : str ) -> dict [str , tuple [FileMeta , torch .Tensor ]]:
168+ def _fast_np_load (fn : str ) -> dict [str , tuple [' FileMeta' , torch .Tensor ]]:
172169 """load *.np file and return memmap and related tensor meta"""
173170
174171 def parse_npy_header (fin : BinaryIO ) -> dict [str , Any ]:
0 commit comments