 
 @dataclasses.dataclass(kw_only=True)
 class DecisionContext:
-  """Additional properties for making a save decision."""
+  """Additional properties for making a save decision.
+
+  This dataclass is populated by the checkpointer framework and passed into the
+  `should_save` method of all `SaveDecisionPolicy` implementations. It provides
+  essential external system context, allowing policies to make safe, state-aware
+  decisions.
+
+  Attributes:
+    is_saving_in_progress: Indicates whether an asynchronous checkpoint
+      save operation is currently running in the background. Policies (like
+      `ContinuousCheckpointingPolicy`) use this to avoid triggering overlapping
+      save operations.
+    reached_preemption: Indicates whether a preemption signal has been
+      received from the cluster manager, meaning the training job is about to
+      be terminated.
+    multiprocessing_options: Configuration details for distributed multihost
+      training. This provides information such as primary host identification
+      for synchronization barriers.
+  """
 
   is_saving_in_progress: bool
   reached_preemption: bool
@@ -166,7 +184,8 @@ class ContinuousCheckpointingPolicy(SaveDecisionPolicy):
   """Checkpoint as often as possible, as long as a save is not ongoing.
 
   This policy evaluates to True as often as possible. It enforces two primary
-  constraints to prevent blocking training or causing other regressions.:
+  constraints to prevent blocking training or causing other regressions.
+
   1. It will never trigger a new save if a save is currently in progress
   (checked via the provided `DecisionContext`); this prevents blocking on an
   ongoing save, which would hurt accelerator utilization.
@@ -254,7 +273,31 @@ def _get_should_save_result() -> bool:
 
 
 class PreemptionCheckpointingPolicy(SaveDecisionPolicy):
-  """Save a checkpoint when a preemption is detected."""
+  """Save a checkpoint when a preemption is detected.
+
+  This policy evaluates to True strictly when the provided `DecisionContext`
+  indicates that a preemption signal has been received (i.e.,
+  `context.reached_preemption` is True). It can be useful for ensuring that
+  training progress is safely stored before the job is killed by the cluster
+  scheduler.
+
+  Note that saving on preemption is not strictly necessary, however. For
+  example, if continuous checkpointing is employed, and checkpoints are saved
+  frequently enough, the cost of re-computing some amount of steps can be
+  cheaper than the cost of waiting for a checkpoint to complete after
+  preemption.
+
+  Methods:
+    should_save(step, previous_steps, *, context):
+      Evaluates whether a preemption signal has been registered in the context.
+
+      Args:
+        step (PolicyCheckpointInfo): Ignored by this policy.
+        previous_steps (Sequence[PolicyCheckpointInfo]): Ignored by this policy.
+        context (DecisionContext): A container for auxiliary information. This
+          policy specifically checks the `reached_preemption` boolean flag to
+          make its decision.
+  """
 
   def should_save(
       self,
@@ -274,7 +317,24 @@ def should_save(
 
 
 class InitialSavePolicy(SaveDecisionPolicy):
-  """Checkpoint as soon as possible if no checkpoints already exist."""
+  """Save a checkpoint as soon as possible if no checkpoints already exist.
+
+  This policy evaluates to True only if the `previous_steps` sequence is empty.
+  It is highly useful for ensuring a baseline checkpoint is created immediately
+  upon starting a fresh training run, while safely evaluating to False if the
+  job is restarting from an existing checkpoint.
+
+  Methods:
+    should_save(step, previous_steps, *, context):
+      Evaluates whether the `previous_steps` history is empty.
+
+      Args:
+        step (PolicyCheckpointInfo): Ignored by this policy.
+        previous_steps (Sequence[PolicyCheckpointInfo]): A chronological list
+          of metadata for previously saved steps. The policy checks if this is
+          empty.
+        context (DecisionContext): Ignored by this policy.
+  """
 
   def should_save(
       self,
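The emptiness check described in this docstring might look like the sketch below. `InitialSavePolicySketch` is a hypothetical name used only for illustration; plain integers stand in for `PolicyCheckpointInfo`, and `context` is passed as `None` because the docstring says this policy ignores it.

```python
from typing import Any, Sequence


class InitialSavePolicySketch:
  """Hypothetical sketch: save only when no earlier checkpoint exists."""

  def should_save(
      self, step: Any, previous_steps: Sequence[Any], *, context: Any
  ) -> bool:
    del step, context  # Ignored, as the docstring states.
    # An empty history means this is a fresh run with nothing saved yet.
    return len(previous_steps) == 0


policy = InitialSavePolicySketch()

# Fresh run: no checkpoints have been saved so far.
print(policy.should_save(0, [], context=None))             # True
# Restarted run: checkpoints at steps 100 and 200 already exist.
print(policy.should_save(201, [100, 200], context=None))   # False
```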
@@ -297,8 +357,25 @@ def should_save(
 class AnySavePolicy(SaveDecisionPolicy):
   """Evaluates all policies and saves if any of them returns True.
 
-  Each policy is evaluated in order, and if all are False, the final result is
-  False. If at least one is True, the final result is True.
+  This policy iterates through a provided sequence of child policies. It
+  evaluates each one in order and returns True immediately if any child policy
+  returns True. If all child policies return False, this policy returns False.
+  It is highly useful for combining time-based, step-based, and event-based
+  saving rules into a single, unified checkpointer configuration.
+
+  Attributes:
+    policies (Sequence[SaveDecisionPolicy]): An ordered collection of underlying
+      policies to evaluate.
+
+  Methods:
+    should_save(step, previous_steps, *, context):
+      Evaluates the sequence of configured policies.
+
+      Args:
+        step (PolicyCheckpointInfo): Passed down to each child policy.
+        previous_steps (Sequence[PolicyCheckpointInfo]): Passed down to each
+          child policy.
+        context (DecisionContext): Passed down to each child policy.
   """
 
   policies: Sequence[SaveDecisionPolicy]