From 38f6dea77ceb0f73839a7477ae806a2ed8f6d479 Mon Sep 17 00:00:00 2001 From: Francesco Banelli Date: Sat, 30 May 2026 18:33:32 +0200 Subject: [PATCH] fix(checkpoint): stop RL replay tests driving real orbax with mocks; keep final alias MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses the second review round on PR #51. - RL replay integration tests (`test_train_rl_replay_buffer_used`, `test_train_rl_replay_buffer_samples`, `test_train_replay_initialization`, `test_train_replay_execution`) drove the real `Checkpointer` with a mocked sampler/model/state. The new always-on `save_final` then forced a real orbax save that serialized a `MagicMock` (grain iterator / state `extras`) and failed. These tests cover `ReplayBuffer`, not checkpointing, so they now patch `train.Checkpointer`. - `save_final`: when the final step is already on disk (training ended on a `save_every` boundary that `save_periodic` just wrote), reuse the existing checkpoint instead of letting orbax reject a duplicate write — which dropped the `final` W&B alias and logged a spurious warning. Fixes the Codex P2; the `final` artifact is now always tagged. - tests: replace the duplicate-step test with a backward-unsaved-step case plus a regression for the reused-final-step path; add `all_steps` to the fake manager. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/flowgym/checkpointing.py | 17 ++++++-- .../training/test_full_integration.py | 19 +++++++-- .../training/test_train_replay_integration.py | 16 ++++++-- tests/unit/training/test_checkpointer.py | 41 +++++++++++++++---- 4 files changed, 74 insertions(+), 19 deletions(-) diff --git a/src/flowgym/checkpointing.py b/src/flowgym/checkpointing.py index adad5e5..2848436 100644 --- a/src/flowgym/checkpointing.py +++ b/src/flowgym/checkpointing.py @@ -384,8 +384,9 @@ def save_final( observed validation metrics. Returns: - The on-disk checkpoint path, or ``None`` when orbax rejected - the save (duplicate step / backward step). + The on-disk checkpoint path (the freshly written step, or an + already-persisted one at the same step), or ``None`` when + orbax rejected a genuinely backward save. """ metrics: Mapping[str, float] | None if val_metrics is not None: @@ -400,7 +401,17 @@ def save_final( metrics = self._last_observed or None else: metrics = self._last_observed or None - path = self._save(state, step, metrics) + # The final step may already be on disk when training ends on a + # ``save_every`` boundary that ``save_periodic`` just wrote. + # Orbax would reject the duplicate write (``_save`` -> ``None``), + # dropping the ``final`` alias upload and emitting a spurious + # warning. Reuse the existing checkpoint instead so the final + # artifact is always tagged regardless of ``num_episodes`` / + # ``save_every`` alignment. + if step in self._mngr.all_steps(): + path: str | None = str(self._ckpt_root / str(step)) + else: + path = self._save(state, step, metrics) if path is None: return None if self._cfg.wandb_upload != "never": diff --git a/tests/integration/training/test_full_integration.py b/tests/integration/training/test_full_integration.py index 268ad42..56791a8 100644 --- a/tests/integration/training/test_full_integration.py +++ b/tests/integration/training/test_full_integration.py @@ -398,8 +398,13 @@ def create_state_fn(img, key): def compute_estimate_fn(img, state, ts): return state, {"dummy": jnp.array(0.0)} - # Track ReplayBuffer initialization - with patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer: + # Track ReplayBuffer initialization. Stub the Checkpointer: this test + # exercises ReplayBuffer wiring, not checkpointing, and the mocked + # env/sampler can't be serialized by a real orbax save. + with ( + patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer, + patch("train.Checkpointer", side_effect=_FakeCheckpointer), + ): train( estimator=model, estimator_config={"config": {"jit": False}}, @@ -453,8 +458,14 @@ def tracking_train_step(*args, **kwargs): train_step_calls.append((args, kwargs)) return original_train_step(*args, **kwargs) - with patch.object( - model, "create_train_step", return_value=tracking_train_step + # Stub the Checkpointer: this test exercises replay sampling, not + # checkpointing, and the mocked env/sampler can't be serialized by a + # real orbax save. + with ( + patch.object( + model, "create_train_step", return_value=tracking_train_step + ), + patch("train.Checkpointer", side_effect=_FakeCheckpointer), ): train( estimator=model, diff --git a/tests/integration/training/test_train_replay_integration.py b/tests/integration/training/test_train_replay_integration.py index 34e990a..baa7518 100644 --- a/tests/integration/training/test_train_replay_integration.py +++ b/tests/integration/training/test_train_replay_integration.py @@ -17,8 +17,14 @@ def test_train_replay_initialization(mock_dependencies): """ReplayBuffer is instantiated once when the RL train loop starts.""" estimator, env, obs, env_state = mock_dependencies - # We want to check if ReplayBuffer is initialized - with patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer: + # We want to check if ReplayBuffer is initialized. Stub the + # Checkpointer: this test exercises ReplayBuffer wiring, not + # checkpointing, and the mocked deps can't be serialized by a real + # orbax save (the end-of-run save_final would otherwise try). + with ( + patch("train.ReplayBuffer", wraps=ReplayBuffer) as mock_buffer, + patch("train.Checkpointer"), + ): train( estimator=estimator, estimator_config={"config": {"jit": False}}, @@ -43,7 +49,11 @@ def test_train_replay_initialization(mock_dependencies): mock_buffer.assert_called_once() -def test_train_replay_execution(mock_dependencies): +# Stub the Checkpointer: this test exercises replay sampling, not +# checkpointing, and the mocked deps can't be serialized by a real orbax +# save (the end-of-run save_final would otherwise try). +@patch("train.Checkpointer") +def test_train_replay_execution(_mock_checkpointer, mock_dependencies): """replay_ratio=1.0 fires the train step on live and replay data.""" estimator, env, obs, env_state = mock_dependencies train_step_fn = estimator.create_train_step.return_value diff --git a/tests/unit/training/test_checkpointer.py b/tests/unit/training/test_checkpointer.py index 8137d91..633728a 100644 --- a/tests/unit/training/test_checkpointer.py +++ b/tests/unit/training/test_checkpointer.py @@ -62,6 +62,9 @@ def metrics(self, step): def latest_step(self): return self._steps[-1] if self._steps else None + def all_steps(self): + return list(self._steps) + def wait_until_finished(self): self.wait_calls += 1 @@ -473,23 +476,43 @@ def test_save_final_falls_back_to_last_observed(tmp_path): assert fake.save_args[-1]["metrics"] == {"mean_error": 0.3} -def test_orbax_rejects_duplicate_step_returns_none(tmp_path): - """A backward ``save`` that reaches orbax reports None and warns. +def test_orbax_rejects_backward_unsaved_step_returns_none(tmp_path): + """A genuinely backward ``save_final`` reports None and warns. - Uses ``save_final`` for the second call so the in-loop same-step - guard on ``save_periodic`` doesn't short-circuit before orbax. + A step behind ``latest`` that was never written is not on disk, so + it reaches orbax and exercises the duplicate/backward rejection. """ cfg = CheckpointConfig() for ckpt, fake, logger_mock in _make(tmp_path, cfg=cfg): - assert ckpt.save_periodic(MagicMock(), step=10) is not None - # ``save_final`` does not guard same-step so the call reaches - # the fake and exercises orbax's duplicate-step rejection. + assert ckpt.save_periodic(MagicMock(), step=20) is not None + # step 10 < latest 20 and was never saved -> reaches orbax. assert ckpt.save_final(MagicMock(), step=10) is None - assert len(fake.save_args) == 2 - assert fake.save_args[1]["ok"] is False + assert fake.save_args[-1]["ok"] is False logger_mock.warning.assert_called() +def test_save_final_reuses_already_persisted_step(tmp_path): + """Regression for the Codex P2 on PR #51. + + When training ends on a ``save_every`` boundary that ``save_periodic`` + just wrote, ``save_final`` must not re-write the step (no duplicate + orbax write, no spurious warning) but must still return its path and + upload the ``final`` alias so the end-of-run artifact is always + tagged, regardless of ``num_episodes`` / ``save_every`` alignment. + """ + cfg = CheckpointConfig(wandb_upload="every") + for ckpt, fake, logger_mock in _make(tmp_path, cfg=cfg): + assert ckpt.save_periodic(MagicMock(), step=10) is not None + n_saves = len(fake.save_args) + out = ckpt.save_final(MagicMock(), step=10) + assert out is not None + # No second orbax write and no duplicate-step warning. + assert len(fake.save_args) == n_saves + logger_mock.warning.assert_not_called() + # The final alias was still uploaded for the existing checkpoint. + assert "final" in logger_mock.artifact.call_args.args[0]["aliases"] + + def test_save_periodic_same_step_after_on_validation_silently_skips(tmp_path): """Regression for antonioterpin #158 review #4.