77import logging
88import os
99import urllib .parse
10- from collections .abc import Callable
10+ from collections .abc import Callable , Sequence
1111from typing import TypeVar
1212
1313import fsspec
@@ -26,19 +26,34 @@ def _get_fs_and_plain_path(path: str) -> tuple[AbstractFileSystem, str]:
2626 return fs , plain_path
2727
2828
29- def _checkpoint_candidates (checkpoint_path : str ) -> list [str ]:
30- fs , plain_path = _get_fs_and_plain_path (checkpoint_path )
31- base_path_protocol = urllib .parse .urlparse (checkpoint_path ).scheme
29+ def _checkpoint_candidates (checkpoint_search_paths : Sequence [str ]) -> list [str ]:
30+ candidates : list [tuple [int , str , str ]] = []
31+ for search_path in checkpoint_search_paths :
32+ candidates .extend (_scan_checkpoint_root (search_path ))
33+
34+ candidates .sort (key = lambda item : (item [0 ], item [1 ]), reverse = True )
35+ ordered_candidates = [candidate for _ , _ , candidate in candidates ]
36+
37+ for search_path in checkpoint_search_paths :
38+ if search_path not in ordered_candidates :
39+ ordered_candidates .append (search_path )
40+ return ordered_candidates
41+
42+
43+ def _scan_checkpoint_root (root_path : str ) -> list [tuple [int , str , str ]]:
44+ """Scan a single root path and return (step, timestamp, path) tuples."""
45+ fs , plain_path = _get_fs_and_plain_path (root_path )
46+ base_path_protocol = urllib .parse .urlparse (root_path ).scheme
3247
3348 def maybe_unstrip_protocol (path : str ) -> str :
3449 if base_path_protocol != "" and urllib .parse .urlparse (path ).scheme == "" :
3550 return f"{ base_path_protocol } ://{ path } "
3651 return path
3752
3853 checkpoint_dirs = [maybe_unstrip_protocol (d ) for d in fs .glob (os .path .join (plain_path , "*" )) if fs .isdir (d )]
39- checkpoint_dirs .append (checkpoint_path )
54+ checkpoint_dirs .append (root_path )
4055
41- candidates : list [tuple [int , str , str ]] = []
56+ results : list [tuple [int , str , str ]] = []
4257 for candidate in checkpoint_dirs :
4358 metadata_path = os .path .join (candidate , "metadata.json" )
4459 if not fs .exists (metadata_path ):
@@ -59,34 +74,29 @@ def maybe_unstrip_protocol(path: str) -> str:
5974
6075 timestamp = metadata .get ("timestamp" )
6176 timestamp_key = str (timestamp ) if timestamp is not None else ""
62- candidates .append ((step_num , timestamp_key , candidate ))
63-
64- candidates .sort (key = lambda item : (item [0 ], item [1 ]), reverse = True )
65- ordered_candidates = [candidate for _ , _ , candidate in candidates ]
66- if checkpoint_path not in ordered_candidates :
67- ordered_candidates .append (checkpoint_path )
77+ results .append ((step_num , timestamp_key , candidate ))
6878
69- return ordered_candidates
79+ return results
7080
7181
7282def restore_grug_state_from_checkpoint (
7383 state : StateT ,
7484 * ,
75- checkpoint_path : str | None ,
85+ checkpoint_search_paths : Sequence [ str ] ,
7686 load_checkpoint_setting : bool | None ,
7787 mesh : jax .sharding .Mesh | None ,
7888 allow_partial : bool ,
7989 _load_fn : Callable [..., StateT ] = load_checkpoint ,
8090) -> StateT :
81- if checkpoint_path is None :
91+ if not checkpoint_search_paths :
8292 if load_checkpoint_setting :
83- raise FileNotFoundError ("load_checkpoint=True but no checkpoint path is configured." )
93+ raise FileNotFoundError ("load_checkpoint=True but no checkpoint search paths are configured." )
8494 return state
8595
8696 if load_checkpoint_setting is False :
8797 return state
8898
89- candidates = _checkpoint_candidates (checkpoint_path )
99+ candidates = _checkpoint_candidates (checkpoint_search_paths )
90100 last_error : FileNotFoundError | None = None
91101
92102 for candidate in candidates :
@@ -98,8 +108,8 @@ def restore_grug_state_from_checkpoint(
98108 allow_partial = allow_partial ,
99109 load_fn = _load_fn ,
100110 )
101- if candidate != checkpoint_path :
102- logger .info ("Loaded checkpoint %s from %s" , checkpoint_path , candidate )
111+ if candidate not in checkpoint_search_paths :
112+ logger .info ("Loaded checkpoint from %s while searching %s" , candidate , checkpoint_search_paths )
103113 return loaded
104114 except FileNotFoundError as exc :
105115 last_error = exc
@@ -108,14 +118,15 @@ def restore_grug_state_from_checkpoint(
108118 )
109119
110120 if load_checkpoint_setting is True :
121+ search_path_summary = ", " .join (checkpoint_search_paths )
111122 attempted = ", " .join (candidates )
112123 if last_error is None :
113- raise FileNotFoundError (f"Could not find checkpoint at { checkpoint_path } " )
124+ raise FileNotFoundError (f"Could not find checkpoint under any of: { search_path_summary } " )
114125 raise FileNotFoundError (
115- f"Could not load a checkpoint from { checkpoint_path } . Attempted: { attempted } "
126+ f"Could not load a checkpoint from search paths { search_path_summary } . Attempted: { attempted } "
116127 ) from last_error
117128
118- logger .info (f "Checkpoint not found at { checkpoint_path } . Starting from scratch." )
129+ logger .info ("Checkpoint not found under %s . Starting from scratch." , checkpoint_search_paths )
119130 return state
120131
121132
@@ -131,7 +142,6 @@ def _load_candidate_state(
131142 return load_fn (
132143 state ,
133144 candidate ,
134- discover_latest = False ,
135145 axis_mapping = None ,
136146 mesh = mesh ,
137147 allow_partial = allow_partial ,
@@ -141,7 +151,6 @@ def _load_candidate_state(
141151 wrapped = load_fn (
142152 {"train_state" : state },
143153 candidate ,
144- discover_latest = False ,
145154 axis_mapping = None ,
146155 mesh = mesh ,
147156 allow_partial = allow_partial ,
0 commit comments