Skip to content

Commit fe4942d

Browse files
committed
Implement SAVED_CHECKPOINT event with proper APIs and documentation
1 parent 3b04445 commit fe4942d

File tree

3 files changed

+78
-3
lines changed

3 files changed

+78
-3
lines changed

docs/source/handlers.rst

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,37 @@ Complete list of generic handlers
3434
state_param_scheduler.StateParamScheduler
3535

3636

37+
Checkpoint Events
38+
-----------------
39+
40+
.. versionadded:: 0.5.0
41+
42+
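The ``Checkpoint`` handler defines a ``SAVED_CHECKPOINT`` event (``CheckpointEvents.SAVED_CHECKPOINT``, also exposed as ``Checkpoint.SAVED_CHECKPOINT``) and fires it on the engine after the save handler has written a checkpoint.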
43+
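This allows users to attach custom handlers that react to checkpoint saves. The event is registered on the engine automatically the first time the checkpoint handler is called; to attach a handler before that first call, register the event yourself with ``engine.register_events(Checkpoint.SAVED_CHECKPOINT)``.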
44+
45+
**Usage:**
46+
47+
.. code-block:: python
48+
49+
from ignite.handlers import Checkpoint
50+
from ignite.engine import Engine, Events
51+
52+
# Setup checkpoint handler
53+
checkpoint_handler = Checkpoint(
54+
{'model': model, 'optimizer': optimizer},
55+
save_dir,
56+
n_saved=2
57+
)
58+
59+
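# Attach handler to checkpoint event (the event must already be registered on the engine: call trainer.register_events(Checkpoint.SAVED_CHECKPOINT) first, or attach after the checkpoint handler's first call)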
60+
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
61+
def on_checkpoint_saved(engine):
62+
print(f"Checkpoint saved at epoch {engine.state.epoch}")
63+
# Add custom logic: notifications, logging, etc.
64+
65+
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
66+
67+
3768
Loggers
3869
--------
3970

ignite/handlers/checkpoint.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,16 @@
2121

2222
import ignite.distributed as idist
2323
from ignite.base import Serializable
24-
from ignite.engine import Engine, Events
24+
from ignite.engine import Engine, Events, EventEnum
2525
from ignite.utils import _tree_apply2, _tree_map
2626

27-
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
27+
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
28+
29+
30+
class CheckpointEvents(EventEnum):
31+
"""Events fired by Checkpoint handler"""
32+
33+
SAVED_CHECKPOINT = "saved_checkpoint"
2834

2935

3036
class BaseSaveHandler(metaclass=ABCMeta):
@@ -276,6 +282,7 @@ class Checkpoint(Serializable):
276282
- `save_on_rank` saves objects on this rank in a distributed configuration.
277283
"""
278284

285+
SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
279286
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
280287
_state_dict_all_req_keys = ("_saved",)
281288

@@ -400,6 +407,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
400407
return new > self._saved[0].priority
401408

402409
def __call__(self, engine: Engine) -> None:
410+
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
411+
engine.register_events(CheckpointEvents.SAVED_CHECKPOINT)
403412
global_step = None
404413
if self.global_step_transform is not None:
405414
global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +469,13 @@ def __call__(self, engine: Engine) -> None:
460469
if self.include_self:
461470
# Now that we've updated _saved, we can add our own state_dict.
462471
checkpoint["checkpointer"] = self.state_dict()
463-
472+
# Store reference to self in engine for event handlers to access
473+
engine._current_checkpoint_handler = self
464474
try:
465475
self.save_handler(checkpoint, filename, metadata)
466476
except TypeError:
467477
self.save_handler(checkpoint, filename)
478+
engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)
468479

469480
def _setup_checkpoint(self) -> Dict[str, Any]:
470481
if self.to_save is not None:

tests/ignite/handlers/test_checkpoint.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,39 @@ def test_load_single_object(obj_to_save, dirname):
18501850
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))
18511851

18521852

1853+
def test_checkpoint_saved_event():
1854+
"""Test that SAVED_CHECKPOINT event is fired correctly."""
1855+
save_handler = MagicMock(spec=BaseSaveHandler)
1856+
to_save = {"model": DummyModel()}
1857+
1858+
checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=2)
1859+
1860+
trainer = Engine(lambda e, b: None)
1861+
trainer.state = State(epoch=0, iteration=0)
1862+
1863+
# Track event firing
1864+
event_count = 0
1865+
1866+
# First, call the checkpoint handler to trigger automatic event registration
1867+
checkpointer(trainer)
1868+
1869+
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
1870+
def on_checkpoint_saved(engine):
1871+
nonlocal event_count
1872+
event_count += 1
1873+
1874+
# Verify the first checkpoint didn't trigger our handler (attached after)
1875+
assert event_count == 0
1876+
1877+
# Second checkpoint - should fire event and trigger our handler
1878+
trainer.state.iteration = 1
1879+
checkpointer(trainer)
1880+
assert event_count == 1
1881+
1882+
# Verify save handler was called twice
1883+
assert save_handler.call_count == 2
1884+
1885+
18531886
@pytest.mark.distributed
18541887
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
18551888
@pytest.mark.parametrize("atomic", [False, True])

0 commit comments

Comments
 (0)