diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 37143df..52229d4 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import argparse import concurrent.futures import ctypes @@ -10,6 +8,7 @@ import threading import time from collections import defaultdict +from collections.abc import Callable from datetime import timedelta from functools import lru_cache from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple @@ -26,8 +25,6 @@ if TYPE_CHECKING: - from collections.abc import Callable - from typing_extensions import TypedDict class FileMeta(TypedDict): @@ -151,8 +148,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]: return ret -def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]: - def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]: +def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]: + def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]: ret = {} with safe_open(fn, framework="pt") as f: for name in f.keys(): # noqa: SIM118 @@ -168,7 +165,7 @@ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]: return ret # deprecated, will be removed in the future - def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]: + def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]: """load *.np file and return memmap and related tensor meta""" def parse_npy_header(fin: BinaryIO) -> dict[str, Any]: