Skip to content

Commit d54f710

Browse files
Add SAVED_CHECKPOINT event to Checkpoint handler (#3440)
Fixes #934 This PR adds a "saved_checkpoint" event that fires after successful checkpoint saves. **Usage:** ```python # Users need to register the event first engine.register_events("saved_checkpoint") @trainer.on("saved_checkpoint") def after_saving(engine): print("Checkpoint saved!") --------- Co-authored-by: vfdev <[email protected]>
1 parent dcac448 commit d54f710

File tree

3 files changed

+78
-3
lines changed

3 files changed

+78
-3
lines changed

docs/source/handlers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Complete list of generic handlers
1111
:toctree: generated
1212

1313
checkpoint.Checkpoint
14+
checkpoint.CheckpointEvents
1415
DiskSaver
1516
checkpoint.ModelCheckpoint
1617
ema_handler.EMAHandler

ignite/handlers/checkpoint.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,21 @@
2121

2222
import ignite.distributed as idist
2323
from ignite.base import Serializable
24-
from ignite.engine import Engine, Events
24+
from ignite.engine import Engine, Events, EventEnum
2525
from ignite.utils import _tree_apply2, _tree_map
2626

27-
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
27+
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
28+
29+
30+
class CheckpointEvents(EventEnum):
31+
"""Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
32+
33+
- SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects
34+
35+
.. versionadded:: 0.5.3
36+
"""
37+
38+
SAVED_CHECKPOINT = "saved_checkpoint"
2839

2940

3041
class BaseSaveHandler(metaclass=ABCMeta):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
264275
to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
265276
)
266277
278+
Respond to checkpoint events:
279+
280+
.. code-block:: python
281+
282+
from ignite.handlers import Checkpoint
283+
from ignite.engine import Engine, Events
284+
285+
checkpoint_handler = Checkpoint(
286+
{'model': model, 'optimizer': optimizer},
287+
save_dir,
288+
n_saved=2
289+
)
290+
291+
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
292+
def on_checkpoint_saved(engine):
293+
print(f"Checkpoint saved at epoch {engine.state.epoch}")
294+
295+
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
296+
297+
Attributes:
298+
SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
299+
:class:`~ignite.handlers.checkpoint.CheckpointEvents`.
300+
267301
.. versionchanged:: 0.4.3
268302
269303
- Checkpoint can save model with same filename.
@@ -274,8 +308,13 @@ class Checkpoint(Serializable):
274308
- `score_name` can be used to define `score_function` automatically without providing `score_function`.
275309
- `save_handler` automatically saves to disk if path to directory is provided.
276310
- `save_on_rank` saves objects on this rank in a distributed configuration.
311+
312+
.. versionchanged:: 0.5.3
313+
314+
- Added ``SAVED_CHECKPOINT`` class attribute.
277315
"""
278316

317+
SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
279318
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
280319
_state_dict_all_req_keys = ("_saved",)
281320

@@ -400,6 +439,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
400439
return new > self._saved[0].priority
401440

402441
def __call__(self, engine: Engine) -> None:
442+
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
443+
engine.register_events(*CheckpointEvents)
403444
global_step = None
404445
if self.global_step_transform is not None:
405446
global_step = self.global_step_transform(engine, engine.last_event_name)
@@ -460,11 +501,11 @@ def __call__(self, engine: Engine) -> None:
460501
if self.include_self:
461502
# Now that we've updated _saved, we can add our own state_dict.
462503
checkpoint["checkpointer"] = self.state_dict()
463-
464504
try:
465505
self.save_handler(checkpoint, filename, metadata)
466506
except TypeError:
467507
self.save_handler(checkpoint, filename)
508+
engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)
468509

469510
def _setup_checkpoint(self) -> Dict[str, Any]:
470511
if self.to_save is not None:

tests/ignite/handlers/test_checkpoint.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,39 @@ def test_load_single_object(obj_to_save, dirname):
18501850
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))
18511851

18521852

1853+
def test_checkpoint_saved_event():
1854+
"""Test that SAVED_CHECKPOINT event is fired correctly."""
1855+
save_handler = MagicMock(spec=BaseSaveHandler)
1856+
to_save = {"model": DummyModel()}
1857+
1858+
checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=2)
1859+
1860+
trainer = Engine(lambda e, b: None)
1861+
trainer.state = State(epoch=0, iteration=0)
1862+
1863+
# Track event firing
1864+
event_count = 0
1865+
1866+
# First, call the checkpoint handler to trigger automatic event registration
1867+
checkpointer(trainer)
1868+
1869+
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
1870+
def on_checkpoint_saved(engine):
1871+
nonlocal event_count
1872+
event_count += 1
1873+
1874+
# Verify the first checkpoint didn't trigger our handler (attached after)
1875+
assert event_count == 0
1876+
1877+
# Second checkpoint - should fire event and trigger our handler
1878+
trainer.state.iteration = 1
1879+
checkpointer(trainer)
1880+
assert event_count == 1
1881+
1882+
# Verify save handler was called twice
1883+
assert save_handler.call_count == 2
1884+
1885+
18531886
@pytest.mark.distributed
18541887
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
18551888
@pytest.mark.parametrize("atomic", [False, True])

0 commit comments

Comments
 (0)