Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/flowgym/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,9 @@ def save_final(
observed validation metrics.

Returns:
The on-disk checkpoint path, or ``None`` when orbax rejected
the save (duplicate step / backward step).
The on-disk checkpoint path (the freshly written step, or an
already-persisted one at the same step), or ``None`` when
orbax rejected a genuinely backward save.
"""
metrics: Mapping[str, float] | None
if val_metrics is not None:
Expand All @@ -400,7 +401,17 @@ def save_final(
metrics = self._last_observed or None
else:
metrics = self._last_observed or None
path = self._save(state, step, metrics)
# The final step may already be on disk when training ends on a
# ``save_every`` boundary that ``save_periodic`` just wrote.
# Orbax would reject the duplicate write (``_save`` -> ``None``),
# dropping the ``final`` alias upload and emitting a spurious
# warning. Reuse the existing checkpoint instead so the final
# artifact is always tagged regardless of ``num_episodes`` /
# ``save_every`` alignment.
if step in self._mngr.all_steps():
path: str | None = str(self._ckpt_root / str(step))
else:
path = self._save(state, step, metrics)
if path is None:
return None
if self._cfg.wandb_upload != "never":
Expand Down
19 changes: 15 additions & 4 deletions tests/integration/training/test_full_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,13 @@ def create_state_fn(img, key):
def compute_estimate_fn(img, state, ts):
return state, {"dummy": jnp.array(0.0)}

# Track ReplayBuffer initialization
with patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer:
# Track ReplayBuffer initialization. Stub the Checkpointer: this test
# exercises ReplayBuffer wiring, not checkpointing, and the mocked
# env/sampler can't be serialized by a real orbax save.
with (
patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer,
patch("train.Checkpointer", side_effect=_FakeCheckpointer),
):
train(
estimator=model,
estimator_config={"config": {"jit": False}},
Expand Down Expand Up @@ -453,8 +458,14 @@ def tracking_train_step(*args, **kwargs):
train_step_calls.append((args, kwargs))
return original_train_step(*args, **kwargs)

with patch.object(
model, "create_train_step", return_value=tracking_train_step
# Stub the Checkpointer: this test exercises replay sampling, not
# checkpointing, and the mocked env/sampler can't be serialized by a
# real orbax save.
with (
patch.object(
model, "create_train_step", return_value=tracking_train_step
),
patch("train.Checkpointer", side_effect=_FakeCheckpointer),
):
train(
estimator=model,
Expand Down
16 changes: 13 additions & 3 deletions tests/integration/training/test_train_replay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ def test_train_replay_initialization(mock_dependencies):
"""ReplayBuffer is instantiated once when the RL train loop starts."""
estimator, env, obs, env_state = mock_dependencies

# We want to check if ReplayBuffer is initialized
with patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer:
# We want to check if ReplayBuffer is initialized. Stub the
# Checkpointer: this test exercises ReplayBuffer wiring, not
# checkpointing, and the mocked deps can't be serialized by a real
# orbax save (the end-of-run save_final would otherwise try).
with (
patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer,
patch("train.Checkpointer"),
):
train(
estimator=estimator,
estimator_config={"config": {"jit": False}},
Expand All @@ -43,7 +49,11 @@ def test_train_replay_initialization(mock_dependencies):
mock_buffer.assert_called_once()


def test_train_replay_execution(mock_dependencies):
# Stub the Checkpointer: this test exercises replay sampling, not
# checkpointing, and the mocked deps can't be serialized by a real orbax
# save (the end-of-run save_final would otherwise try).
@patch("train.Checkpointer")
def test_train_replay_execution(_mock_checkpointer, mock_dependencies):
"""replay_ratio=1.0 fires the train step on live and replay data."""
estimator, env, obs, env_state = mock_dependencies
train_step_fn = estimator.create_train_step.return_value
Expand Down
41 changes: 32 additions & 9 deletions tests/unit/training/test_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def metrics(self, step):
def latest_step(self):
    """Return the last entry of ``self._steps``, or ``None`` when empty."""
    if not self._steps:
        return None
    return self._steps[-1]

def all_steps(self):
    """Return a fresh list holding every entry of ``self._steps``.

    A copy is handed back so callers cannot mutate the fake's own
    step record through the returned list.
    """
    return [*self._steps]

def wait_until_finished(self):
self.wait_calls += 1

Expand Down Expand Up @@ -473,23 +476,43 @@ def test_save_final_falls_back_to_last_observed(tmp_path):
assert fake.save_args[-1]["metrics"] == {"mean_error": 0.3}


def test_orbax_rejects_duplicate_step_returns_none(tmp_path):
"""A backward ``save`` that reaches orbax reports None and warns.
def test_orbax_rejects_backward_unsaved_step_returns_none(tmp_path):
"""A genuinely backward ``save_final`` reports None and warns.

Uses ``save_final`` for the second call so the in-loop same-step
guard on ``save_periodic`` doesn't short-circuit before orbax.
A step behind ``latest`` that was never written is not on disk, so
it reaches orbax and exercises the duplicate/backward rejection.
"""
cfg = CheckpointConfig()
for ckpt, fake, logger_mock in _make(tmp_path, cfg=cfg):
assert ckpt.save_periodic(MagicMock(), step=10) is not None
# ``save_final`` does not guard same-step so the call reaches
# the fake and exercises orbax's duplicate-step rejection.
assert ckpt.save_periodic(MagicMock(), step=20) is not None
# step 10 < latest 20 and was never saved -> reaches orbax.
assert ckpt.save_final(MagicMock(), step=10) is None
assert len(fake.save_args) == 2
assert fake.save_args[1]["ok"] is False
assert fake.save_args[-1]["ok"] is False
logger_mock.warning.assert_called()


def test_save_final_reuses_already_persisted_step(tmp_path):
    """``save_final`` on an already-written step reuses that checkpoint.

    Regression for the Codex P2 on PR #51.

    Scenario: training ends on a ``save_every`` boundary, so
    ``save_periodic`` has just written the very step that ``save_final``
    is asked to persist. Expected outcome:

    * the step is not written a second time (no duplicate orbax write);
    * no spurious warning is logged;
    * a path is still returned and the ``final`` alias is still
      uploaded, so the end-of-run artifact is always tagged regardless
      of ``num_episodes`` / ``save_every`` alignment.
    """
    upload_cfg = CheckpointConfig(wandb_upload="every")
    for checkpointer, fake, logger_mock in _make(tmp_path, cfg=upload_cfg):
        periodic_path = checkpointer.save_periodic(MagicMock(), step=10)
        assert periodic_path is not None
        writes_before_final = len(fake.save_args)

        final_path = checkpointer.save_final(MagicMock(), step=10)
        assert final_path is not None

        # The write count is unchanged and nothing was warned about:
        # the existing step was reused rather than saved again.
        assert len(fake.save_args) == writes_before_final
        logger_mock.warning.assert_not_called()

        # The upload for the existing checkpoint still carries the
        # ``final`` alias.
        uploaded_artifact = logger_mock.artifact.call_args.args[0]
        assert "final" in uploaded_artifact["aliases"]


def test_save_periodic_same_step_after_on_validation_silently_skips(tmp_path):
"""Regression for antonioterpin #158 review #4.

Expand Down