Skip to content
31 changes: 31 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,37 @@ Complete list of generic handlers
state_param_scheduler.StateParamScheduler


Checkpoint Events
-----------------

.. versionadded:: 0.5.0

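The Checkpoint handler provides a ``SAVED_CHECKPOINT`` event that is fired after each successful checkpoint save.
This allows users to attach custom handlers that react to checkpoint operations. The event is registered on the
engine automatically the first time the checkpoint handler is called; to attach a handler before that first call,
register the event explicitly with ``engine.register_events(Checkpoint.SAVED_CHECKPOINT)``.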

**Usage:**

.. code-block:: python

    from ignite.handlers import Checkpoint
    from ignite.engine import Engine, Events

    # Setup checkpoint handler
    checkpoint_handler = Checkpoint(
        {'model': model, 'optimizer': optimizer},
        save_dir,
        n_saved=2
    )

    # The event is registered on the engine the first time the checkpoint
    # handler runs; register it up front to attach a handler before that.
    trainer.register_events(Checkpoint.SAVED_CHECKPOINT)

    # Attach handler to checkpoint event
    @trainer.on(Checkpoint.SAVED_CHECKPOINT)
    def on_checkpoint_saved(engine):
        print(f"Checkpoint saved at epoch {engine.state.epoch}")
        # Add custom logic: notifications, logging, etc.

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)


Loggers
--------

Expand Down
15 changes: 12 additions & 3 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@

import ignite.distributed as idist
from ignite.base import Serializable
from ignite.engine import Engine, Events
from ignite.engine import Engine, Events, EventEnum
from ignite.utils import _tree_apply2, _tree_map

__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]


class CheckpointEvents(EventEnum):
    """Custom events that the ``Checkpoint`` handler fires on the engine.

    - SAVED_CHECKPOINT : fired once the save handler has been called for a checkpoint.
    """

    SAVED_CHECKPOINT = "saved_checkpoint"


class BaseSaveHandler(metaclass=ABCMeta):
Expand Down Expand Up @@ -276,6 +282,7 @@ class Checkpoint(Serializable):
- `save_on_rank` saves objects on this rank in a distributed configuration.
"""

SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
_state_dict_all_req_keys = ("_saved",)

Expand Down Expand Up @@ -400,6 +407,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
return new > self._saved[0].priority

def __call__(self, engine: Engine) -> None:
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
engine.register_events(CheckpointEvents.SAVED_CHECKPOINT)
global_step = None
if self.global_step_transform is not None:
global_step = self.global_step_transform(engine, engine.last_event_name)
Expand Down Expand Up @@ -460,11 +469,11 @@ def __call__(self, engine: Engine) -> None:
if self.include_self:
# Now that we've updated _saved, we can add our own state_dict.
checkpoint["checkpointer"] = self.state_dict()

try:
self.save_handler(checkpoint, filename, metadata)
except TypeError:
self.save_handler(checkpoint, filename)
engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)

def _setup_checkpoint(self) -> Dict[str, Any]:
if self.to_save is not None:
Expand Down
33 changes: 33 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,39 @@ def test_load_single_object(obj_to_save, dirname):
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))


def test_checkpoint_saved_event():
    """Check that ``Checkpoint`` fires ``SAVED_CHECKPOINT`` after a save."""
    mock_saver = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint({"model": DummyModel()}, save_handler=mock_saver, n_saved=2)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    # The event only becomes registered on the engine once the checkpointer
    # has run, so perform a first save before attaching a listener to it.
    checkpointer(trainer)

    # One entry is appended per notification received by the listener.
    notifications = []

    @trainer.on(Checkpoint.SAVED_CHECKPOINT)
    def on_checkpoint_saved(engine):
        notifications.append(engine)

    # The listener was attached after the first save, so it has not run yet.
    assert len(notifications) == 0

    # A second save must notify the listener exactly once.
    trainer.state.iteration = 1
    checkpointer(trainer)
    assert len(notifications) == 1

    # Both checkpointer invocations reached the save handler.
    assert mock_saver.call_count == 2


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.parametrize("atomic", [False, True])
Expand Down
Loading