
Add checkpoint callback support to NLE.fit() #57

Open
lauradriscoll wants to merge 6 commits into dirmeier:main from lauradriscoll:add-checkpoint-callback


@lauradriscoll
Contributor

Summary

Adds an optional checkpoint_callback parameter to NLE.fit(), invoked every checkpoint_every iterations, to enable periodic checkpointing and real-time logging during training.

Motivation

  • Long training runs need checkpointing for crash recovery
  • Real-time monitoring with tools like Weights & Biases requires periodic callbacks
  • Users want to save optimizer state to resume training exactly

Changes

  • Added a checkpoint_callback parameter (optional callable) to the fit() method
  • Added a checkpoint_every parameter (default: 100 iterations)
  • The callback receives iteration, params, train_loss, val_loss, and state
  • Fully backward compatible: checkpoint_callback defaults to None, so existing calls are unaffected
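To make the callback contract concrete, here is a minimal sketch of how a training loop could invoke such a callback. It is not the code in this PR: run_training_loop, step_fn, and toy_step are hypothetical names used only for illustration, and the 1-based iteration count (so the first call happens at iteration checkpoint_every) is an assumption that the description does not confirm.

```python
from typing import Any, Callable, Optional

# Hypothetical type alias for the five callback arguments:
# (iteration, params, train_loss, val_loss, state).
CheckpointCallback = Callable[[int, Any, float, float, Any], None]


def run_training_loop(
    n_iter: int,
    step_fn: Callable[[Any, Any], tuple],
    params: Any,
    state: Any,
    checkpoint_callback: Optional[CheckpointCallback] = None,
    checkpoint_every: int = 100,
):
    """Toy loop; step_fn stands in for one optimizer update plus validation."""
    losses = []
    # Assumption: iterations are counted from 1.
    for i in range(1, n_iter + 1):
        params, state, train_loss, val_loss = step_fn(params, state)
        losses.append((train_loss, val_loss))
        # Fire only when a callback was supplied and the interval is reached.
        if checkpoint_callback is not None and i % checkpoint_every == 0:
            checkpoint_callback(i, params, train_loss, val_loss, state)
    return params, losses


def toy_step(params, state):
    # Pretend update: shrink a scalar parameter and count steps in state.
    new_params = params * 0.9
    train_loss = float(new_params**2)
    return new_params, state + 1, train_loss, train_loss + 0.1


seen = []
final_params, losses = run_training_loop(
    250,
    toy_step,
    1.0,
    0,
    checkpoint_callback=lambda i, p, tl, vl, s: seen.append(i),
    checkpoint_every=100,
)
print(seen)
```

With checkpoint_callback left at None the loop never enters the callback branch, which is what keeps existing calls unaffected.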

Usage Example

import pickle
from pathlib import Path

def save_checkpoint(iteration, params, train_loss, val_loss, state):
    """Pickle everything the callback receives to checkpoints/ckpt_<iteration>.pkl."""
    checkpoint = {
        "iteration": iteration,
        "params": params,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "state": state,  # stored exactly as passed in by fit()
    }
    # Create the output directory on first use.
    Path("checkpoints").mkdir(exist_ok=True)
    with open(f"checkpoints/ckpt_{iteration}.pkl", "wb") as f:
        pickle.dump(checkpoint, f)
    print(f"Saved checkpoint at iteration {iteration}")

# Use with NLE
params, losses = model.fit(
    rng_key,
    data,
    checkpoint_callback=save_checkpoint,
    checkpoint_every=500
)
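For crash recovery, the files written by save_checkpoint need to be read back. The helper below is a sketch of one way to find and load the most recent one; load_latest_checkpoint is a hypothetical name and is not part of this PR. It assumes the checkpoints/ckpt_<iteration>.pkl naming from the example above. How the restored params and state are fed back into fit() is not covered by this description, so that step is left out.

```python
import pickle
import re
from pathlib import Path
from typing import Optional


def load_latest_checkpoint(directory: str = "checkpoints") -> Optional[dict]:
    """Return the checkpoint dict with the highest iteration, or None if none exist."""
    pattern = re.compile(r"ckpt_(\d+)\.pkl$")
    candidates = []
    for path in Path(directory).glob("ckpt_*.pkl"):
        match = pattern.search(path.name)
        if match:
            # Compare iterations numerically: as strings, "1000" sorts before "500".
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    _, latest = max(candidates)
    with open(latest, "rb") as f:
        return pickle.load(f)


checkpoint = load_latest_checkpoint()
if checkpoint is not None:
    print(f"Latest checkpoint is from iteration {checkpoint['iteration']}")
```

Unpickling executes arbitrary code, so this should only be pointed at checkpoint files you wrote yourself.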

Testing

  • Backward compatible (existing code works without changes)
  • Callback is optional (defaults to None)
  • Add unit test (named nle_test_checkpoint.py)
