2121
2222import ignite .distributed as idist
2323from ignite .base import Serializable
24- from ignite .engine import Engine , Events
24+ from ignite .engine import Engine , Events , EventEnum
2525from ignite .utils import _tree_apply2 , _tree_map
2626
27- __all__ = ["Checkpoint" , "DiskSaver" , "ModelCheckpoint" , "BaseSaveHandler" ]
27+ __all__ = ["Checkpoint" , "DiskSaver" , "ModelCheckpoint" , "BaseSaveHandler" , "CheckpointEvents" ]
28+
29+
30+ class CheckpointEvents (EventEnum ):
31+ """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`
32+
33+ - SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects
34+
35+ .. versionadded:: 0.5.3
36+ """
37+
38+ SAVED_CHECKPOINT = "saved_checkpoint"
2839
2940
3041class BaseSaveHandler (metaclass = ABCMeta ):
@@ -264,6 +275,29 @@ class Checkpoint(Serializable):
264275 to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
265276 )
266277
278+ Respond to checkpoint events:
279+
280+ .. code-block:: python
281+
282+ from ignite.handlers import Checkpoint
283+ from ignite.engine import Engine, Events
284+
285+ checkpoint_handler = Checkpoint(
286+ {'model': model, 'optimizer': optimizer},
287+ save_dir,
288+ n_saved=2
289+ )
290+
291+ @trainer.on(Checkpoint.SAVED_CHECKPOINT)
292+ def on_checkpoint_saved(engine):
293+ print(f"Checkpoint saved at epoch {engine.state.epoch}")
294+
295+ trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
296+
297+ Attributes:
298+ SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
299+ :class:`~ignite.handlers.checkpoint.CheckpointEvents`.
300+
267301 .. versionchanged:: 0.4.3
268302
269303 - Checkpoint can save model with same filename.
@@ -274,8 +308,13 @@ class Checkpoint(Serializable):
274308 - `score_name` can be used to define `score_function` automatically without providing `score_function`.
275309 - `save_handler` automatically saves to disk if path to directory is provided.
276310 - `save_on_rank` saves objects on this rank in a distributed configuration.
311+
312+ .. versionchanged:: 0.5.3
313+
314+ - Added ``SAVED_CHECKPOINT`` class attribute.
277315 """
278316
317+ SAVED_CHECKPOINT = CheckpointEvents .SAVED_CHECKPOINT
279318 Item = NamedTuple ("Item" , [("priority" , int ), ("filename" , str )])
280319 _state_dict_all_req_keys = ("_saved" ,)
281320
@@ -400,6 +439,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
400439 return new > self ._saved [0 ].priority
401440
402441 def __call__ (self , engine : Engine ) -> None :
442+ if not engine .has_registered_events (CheckpointEvents .SAVED_CHECKPOINT ):
443+ engine .register_events (* CheckpointEvents )
403444 global_step = None
404445 if self .global_step_transform is not None :
405446 global_step = self .global_step_transform (engine , engine .last_event_name )
@@ -460,11 +501,11 @@ def __call__(self, engine: Engine) -> None:
460501 if self .include_self :
461502 # Now that we've updated _saved, we can add our own state_dict.
462503 checkpoint ["checkpointer" ] = self .state_dict ()
463-
464504 try :
465505 self .save_handler (checkpoint , filename , metadata )
466506 except TypeError :
467507 self .save_handler (checkpoint , filename )
508+ engine .fire_event (CheckpointEvents .SAVED_CHECKPOINT )
468509
469510 def _setup_checkpoint (self ) -> Dict [str , Any ]:
470511 if self .to_save is not None :
0 commit comments