Skip to content

Commit eac2269

Browse files
committed
Keep forced checkpoints permanent
1 parent 990af90 commit eac2269

2 files changed

Lines changed: 39 additions & 8 deletions

File tree

lib/levanter/src/levanter/checkpoint.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -521,14 +521,15 @@ def on_step(self, *, tree: PyTree, step: int, force: bool = False):
521521
my_should_save = force
522522
my_save_permanent_ckpt = force
523523

524-
current_every = self._get_current_step_save_interval(step)
525-
last_save_time = self._dt_now_injection() - self._last_save_time
526-
if current_every is not None and step % current_every == 0:
527-
my_should_save = True
528-
my_save_permanent_ckpt = True
529-
elif self.save_interval and last_save_time >= self.save_interval:
530-
my_should_save = True
531-
my_save_permanent_ckpt = False
524+
if not force:
525+
current_every = self._get_current_step_save_interval(step)
526+
last_save_time = self._dt_now_injection() - self._last_save_time
527+
if current_every is not None and step % current_every == 0:
528+
my_should_save = True
529+
my_save_permanent_ckpt = True
530+
elif self.save_interval and last_save_time >= self.save_interval:
531+
my_should_save = True
532+
my_save_permanent_ckpt = False
532533

533534
should_save, save_permanent_ckpt = broadcast_one_to_all(
534535
jnp.array([my_should_save, my_save_permanent_ckpt], dtype=jnp.bool_)

lib/levanter/tests/test_checkpoint.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -540,6 +540,36 @@ def advance_time(delta_seconds):
540540
assert _get_checkpoint_steps(tmpdir) == [1, 2]
541541

542542

543+
def test_checkpointer_force_save_uses_permanent_path_even_when_time_policy_elapsed():
    """Check that a forced save is written to the permanent directory.

    The fake clock is advanced by exactly one save interval before the forced
    step, so the time-based policy has also elapsed at that point. The
    checkpoint for step 1 must still appear under the permanent base path and
    the temporary base path must stay empty.
    """
    interval_seconds = 10
    # One-element list so the injected clock can be advanced in place.
    clock = [datetime.datetime(2021, 1, 1, 0, 0, 0)]

    def current_time():
        return clock[0]

    with (
        tempfile.TemporaryDirectory(prefix="checkpoints") as permanent_dir,
        tempfile.TemporaryDirectory(prefix="temp_checkpoints") as temporary_dir,
    ):
        checkpointer = Checkpointer(
            permanent_dir,
            timedelta(seconds=interval_seconds),
            [],
            temporary_base_path=temporary_dir,
            dt_now_injection=current_time,
        )

        _on_step(checkpointer, 0)

        # Let a full save interval pass, then force a save at step 1.
        clock[0] += timedelta(seconds=interval_seconds)
        _on_step(checkpointer, 1, force=True)
        checkpointer.wait_until_finished()

        permanent_steps = _get_checkpoint_steps(permanent_dir)
        assert permanent_steps == [1]

        temporary_entries = list(pathlib.Path(temporary_dir).iterdir())
        assert temporary_entries == []
571+
572+
543573
def test_load_from_checkpoint_or_initialize():
544574
In = Axis("in", 2)
545575
Out = Axis("out", 1)

0 commit comments

Comments
 (0)