@@ -34,7 +34,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
3434
3535 loaded = restore_grug_state_from_checkpoint (
3636 {"state" : "init" },
37- checkpoint_path = str (checkpoint_root ),
37+ checkpoint_search_paths = [ str (checkpoint_root )] ,
3838 load_checkpoint_setting = True ,
3939 mesh = None ,
4040 allow_partial = False ,
@@ -61,7 +61,7 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
6161
6262 loaded = restore_grug_state_from_checkpoint (
6363 {"state" : "init" },
64- checkpoint_path = str (checkpoint_root ),
64+ checkpoint_search_paths = [ str (checkpoint_root )] ,
6565 load_checkpoint_setting = None ,
6666 mesh = None ,
6767 allow_partial = False ,
@@ -86,15 +86,15 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
8686 with pytest .raises (FileNotFoundError , match = "Could not load a checkpoint" ):
8787 restore_grug_state_from_checkpoint (
8888 {"state" : "init" },
89- checkpoint_path = str (checkpoint_root ),
89+ checkpoint_search_paths = [ str (checkpoint_root )] ,
9090 load_checkpoint_setting = True ,
9191 mesh = None ,
9292 allow_partial = False ,
9393 _load_fn = fake_load ,
9494 )
9595
9696
97- def test_restore_discovers_candidates_across_additional_paths (tmp_path : Path ):
97+ def test_restore_discovers_candidates_across_search_paths (tmp_path : Path ):
9898 permanent_root = tmp_path / "checkpoints"
9999 temp_root = tmp_path / "checkpoints-temp"
100100
@@ -109,20 +109,19 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
109109
110110 loaded = restore_grug_state_from_checkpoint (
111111 {"state" : "init" },
112- checkpoint_path = str (permanent_root ),
112+ checkpoint_search_paths = [ str (permanent_root ), str ( temp_root )] ,
113113 load_checkpoint_setting = True ,
114114 mesh = None ,
115115 allow_partial = False ,
116- additional_checkpoint_paths = [str (temp_root )],
117116 _load_fn = fake_load ,
118117 )
119118
120- # step-150 from temp root should be preferred (highest step)
119+ # step-150 from temp root should be preferred (highest step).
121120 assert attempted == [str (temp_root / "step-150" )]
122121 assert loaded == {"loaded_from" : str (temp_root / "step-150" )}
123122
124123
125- def test_restore_respects_explicit_checkpoint_path_with_additional_paths (tmp_path : Path ):
124+ def test_restore_respects_explicit_checkpoint_path_as_single_search_path (tmp_path : Path ):
126125 permanent_root = tmp_path / "checkpoints"
127126 temp_root = tmp_path / "checkpoints-temp"
128127 explicit_checkpoint = permanent_root / "step-100"
@@ -138,11 +137,10 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
138137
139138 loaded = restore_grug_state_from_checkpoint (
140139 {"state" : "init" },
141- checkpoint_path = str (explicit_checkpoint ),
140+ checkpoint_search_paths = [ str (explicit_checkpoint )] ,
142141 load_checkpoint_setting = True ,
143142 mesh = None ,
144143 allow_partial = False ,
145- additional_checkpoint_paths = [str (temp_root )],
146144 _load_fn = fake_load ,
147145 )
148146
@@ -167,11 +165,10 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
167165
168166 loaded = restore_grug_state_from_checkpoint (
169167 {"state" : "init" },
170- checkpoint_path = str (permanent_root ),
168+ checkpoint_search_paths = [ str (permanent_root ), str ( temp_root )] ,
171169 load_checkpoint_setting = None ,
172170 mesh = None ,
173171 allow_partial = False ,
174- additional_checkpoint_paths = [str (temp_root )],
175172 _load_fn = fake_load ,
176173 )
177174
@@ -197,7 +194,7 @@ def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path
197194
198195 loaded_legacy = restore_grug_state_from_checkpoint (
199196 template_state ,
200- checkpoint_path = str (checkpoint_root ),
197+ checkpoint_search_paths = [ str (checkpoint_root )] ,
201198 load_checkpoint_setting = True ,
202199 mesh = None ,
203200 allow_partial = False ,
@@ -210,7 +207,7 @@ def test_restore_supports_legacy_wrapped_and_current_checkpoint_formats(tmp_path
210207
211208 loaded_current = restore_grug_state_from_checkpoint (
212209 template_state ,
213- checkpoint_path = str (checkpoint_root ),
210+ checkpoint_search_paths = [ str (checkpoint_root )] ,
214211 load_checkpoint_setting = True ,
215212 mesh = None ,
216213 allow_partial = False ,
0 commit comments