66from collections import defaultdict
77from enum import Enum
88from pathlib import Path
9+ from typing import Any , DefaultDict
910
1011from ..mapping import CheckpointingException , ShardedStateDict , StateDict
1112from .async_utils import AsyncCallsQueue , AsyncRequest
@@ -18,7 +19,8 @@ class StrategyAction(Enum):
1819 SAVE_SHARDED = 'save_sharded'
1920
2021
21- default_strategies = defaultdict (dict )
22+ _import_trigger = None
23+ default_strategies : DefaultDict [str , dict [tuple , Any ]] = defaultdict (dict )
2224
2325async_calls = AsyncCallsQueue ()
2426
@@ -35,7 +37,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
3537 from .torch import _import_trigger
3638 except ImportError as e :
3739 raise CheckpointingException (
38- f'Cannot import a default strategy for: { (action .value , backend , version )} . Error: { e } . Hint: { error_hint } '
40+ f'Cannot import a default strategy for: { (action .value , backend , version )} . '
41+ f'Error: { e } . Hint: { error_hint } '
3942 ) from e
4043 try :
4144 return default_strategies [action .value ][(backend , version )]
@@ -46,7 +49,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
4649
4750
4851class LoadStrategyBase (ABC ):
49- """Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version."""
52+ """Base class for a load strategy. Requires implementing checks for compatibility with a
53+ given checkpoint version."""
5054
5155 @abstractmethod
5256 def check_backend_compatibility (self , loaded_version ):
@@ -63,7 +67,8 @@ def can_handle_sharded_objects(self):
6367
6468
6569class SaveStrategyBase (ABC ):
66- """Base class for a save strategy. Requires defining a backend type and version of the saved format."""
70+ """Base class for a save strategy. Requires defining a backend type and
71+ version of the saved format."""
6772
6873 def __init__ (self , backend : str , version : int ):
6974 self .backend = backend
0 commit comments