
fix(checkpoint): RL replay tests + keep final alias (follow-up to #51) #52

Open
fbanelli wants to merge 1 commit into dev from fix/checkpointer-rl-final-save



@fbanelli
Collaborator

Follow-up to the now-merged #51. A regression in the always-on RL save_final (added in #51) was found while reviewing the private backport. This PR lands the fix on dev properly: the fix was first pushed by mistake onto the already-merged #51 branch, and that commit is not on dev.

Changes

  • RL replay tests (test_train_rl_replay_buffer_used, test_train_rl_replay_buffer_samples, test_train_replay_initialization, test_train_replay_execution) drove the real Checkpointer with a mocked sampler/model/state. The always-on save_final then forced a real orbax save that serialized a MagicMock (grain iterator / state extras) and failed. These tests cover ReplayBuffer, not checkpointing, so they now patch train.Checkpointer (as test_train_rl_full_integration already does).
  • save_final: when the final step is already on disk (training ended on a save_every boundary that save_periodic just wrote), reuse the existing checkpoint instead of attempting a duplicate write. Orbax rejected that write, which dropped the final W&B alias and logged a spurious warning. The final artifact is now always tagged.
  • tests: replace the duplicate-step test with a backward-unsaved-step case plus a regression for the reused-final-step path; add all_steps to the fake manager.
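
The reuse-if-present behaviour described for save_final can be sketched as follows. This is a minimal illustration, not the code in this PR: `FakeManager`, its `save` method, the free-standing `save_final` signature, and the `final@<step>` alias string are all hypothetical stand-ins, and the fake manager does not reproduce the orbax API. Only the `all_steps` name comes from the change list above.

```python
from dataclasses import dataclass, field


@dataclass
class FakeManager:
    """Hypothetical stand-in for a checkpoint manager (not the orbax API)."""

    saved: list[int] = field(default_factory=list)

    def all_steps(self) -> list[int]:
        # Steps that already exist on "disk".
        return list(self.saved)

    def save(self, step: int) -> None:
        # Mimic a manager that refuses to write a step that already exists.
        if step in self.saved:
            raise ValueError(f"step {step} already exists")
        self.saved.append(step)


def save_final(manager: FakeManager, step: int, aliases: list[str]) -> list[str]:
    """Write the final checkpoint unless it already exists, then tag it."""
    if step not in manager.all_steps():
        manager.save(step)
    # Tagging runs on both paths, so the alias is not lost when the step
    # was already written by a periodic save.
    aliases.append(f"final@{step}")
    return aliases


# Training ended on a periodic-save boundary: step 200 is already present.
manager = FakeManager(saved=[100, 200])
tags = save_final(manager, 200, [])
```

Checking `all_steps` before writing keeps the duplicate write from being attempted at all, so there is no rejection to recover from and the tagging step is reached unconditionally.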

Test plan

  • tests/unit/training/test_checkpointer.py + tests/integration/training/{test_full_integration,test_train_replay_integration}.py — 54 passed
  • ruff 0.14.1 / pydoclint / basedpyright clean on changed files
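
For readers unfamiliar with the patching approach applied to the replay tests, here is a rough sketch using `unittest.mock`. Everything in it is assumed for illustration: `train` is a `SimpleNamespace` standing in for the real module, and `RealCheckpointer`, `run_training`, and the test function are hypothetical names, not the project's code.

```python
import types
from unittest.mock import MagicMock, patch


class RealCheckpointer:
    """Hypothetical checkpointer whose final save cannot handle mocks."""

    def save_final(self, state):
        # Stand-in for a real save that fails when asked to serialize a mock.
        if isinstance(state, MagicMock):
            raise TypeError("cannot serialize a MagicMock")


def run_training(state):
    """Hypothetical entry point; resolves Checkpointer at call time."""
    checkpointer = train.Checkpointer()
    checkpointer.save_final(state)
    return "done"


# Stand-in for the `train` module that the tests patch.
train = types.SimpleNamespace(Checkpointer=RealCheckpointer, run=run_training)


def test_replay_path_with_patched_checkpointer():
    # Replace the class for the duration of the test so that no real save
    # is attempted with mocked state.
    with patch.object(train, "Checkpointer") as fake_cls:
        result = train.run(MagicMock())
    # The code under test still asked for a final save exactly once.
    fake_cls.return_value.save_final.assert_called_once()
    return result
```

Patching at the attribute the training code looks up means the replay logic runs unchanged while the checkpointing side effect is replaced by a recorder.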

🤖 Generated with Claude Code

…keep final alias

Addresses the second review round on PR #51.

- RL replay integration tests (`test_train_rl_replay_buffer_used`,
  `test_train_rl_replay_buffer_samples`, `test_train_replay_initialization`,
  `test_train_replay_execution`) drove the real `Checkpointer` with a
  mocked sampler/model/state. The new always-on `save_final` then forced
  a real orbax save that serialized a `MagicMock` (grain iterator / state
  `extras`) and failed. These tests cover `ReplayBuffer`, not
  checkpointing, so they now patch `train.Checkpointer`.
- `save_final`: when the final step is already on disk (training ended on
  a `save_every` boundary that `save_periodic` just wrote), reuse the
  existing checkpoint instead of letting orbax reject a duplicate write —
  which dropped the `final` W&B alias and logged a spurious warning.
  Fixes the Codex P2; the `final` artifact is now always tagged.
- tests: replace the duplicate-step test with a backward-unsaved-step
  case plus a regression for the reused-final-step path; add `all_steps`
  to the fake manager.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
