@@ -122,6 +122,34 @@ def fake_load(state, path, *, discover_latest, axis_mapping, mesh, allow_partial
122122 assert loaded == {"loaded_from" : str (temp_root / "step-150" )}
123123
124124
125+ def test_restore_respects_explicit_checkpoint_path_with_additional_paths (tmp_path : Path ):
126+ permanent_root = tmp_path / "checkpoints"
127+ temp_root = tmp_path / "checkpoints-temp"
128+ explicit_checkpoint = permanent_root / "step-100"
129+
130+ _write_checkpoint_metadata (explicit_checkpoint , step = 100 , timestamp = "2026-03-17T00:00:00" )
131+ _write_checkpoint_metadata (temp_root / "step-150" , step = 150 , timestamp = "2026-03-17T06:00:00" )
132+
133+ attempted : list [str ] = []
134+
135+ def fake_load (state , path , * , discover_latest , axis_mapping , mesh , allow_partial ):
136+ attempted .append (path )
137+ return {"loaded_from" : path }
138+
139+ loaded = restore_grug_state_from_checkpoint (
140+ {"state" : "init" },
141+ checkpoint_path = str (explicit_checkpoint ),
142+ load_checkpoint_setting = True ,
143+ mesh = None ,
144+ allow_partial = False ,
145+ additional_checkpoint_paths = [str (temp_root )],
146+ _load_fn = fake_load ,
147+ )
148+
149+ assert attempted == [str (explicit_checkpoint )]
150+ assert loaded == {"loaded_from" : str (explicit_checkpoint )}
151+
152+
125153def test_restore_falls_back_from_temp_to_permanent (tmp_path : Path ):
126154 permanent_root = tmp_path / "checkpoints"
127155 temp_root = tmp_path / "checkpoints-temp"
0 commit comments