Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import argparse
import concurrent.futures
import ctypes
Expand All @@ -10,6 +8,7 @@
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from datetime import timedelta
from functools import lru_cache
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
Expand All @@ -26,8 +25,6 @@


if TYPE_CHECKING:
from collections.abc import Callable

from typing_extensions import TypedDict

class FileMeta(TypedDict):
Expand Down Expand Up @@ -151,8 +148,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
return ret


def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
ret = {}
with safe_open(fn, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
Expand All @@ -168,7 +165,7 @@ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
return ret

# deprecated, will be removed in the future
def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
"""load *.np file and return memmap and related tensor meta"""

def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
Expand Down