Commit b192e7f

Orbax Authors committed
Improve InitialSavePolicy, AnySavePolicy, PreemptionCheckpointingPolicy and DecisionContext class docstrings
PiperOrigin-RevId: 876055972
1 parent db84474 commit b192e7f

1 file changed, +83 -6 lines changed

checkpoint/orbax/checkpoint/_src/checkpoint_managers/save_decision_policy.py

Lines changed: 83 additions & 6 deletions
@@ -31,7 +31,25 @@
 
 @dataclasses.dataclass(kw_only=True)
 class DecisionContext:
-  """Additional properties for making a save decision."""
+  """Additional properties for making a save decision.
+
+  This dataclass is populated by the checkpointer framework and passed into the
+  `should_save` method of all `SaveDecisionPolicy` implementations. It provides
+  essential external system context, allowing policies to make safe, state-aware
+  decisions.
+
+  Attributes:
+    is_saving_in_progress: Indicates whether an asynchronous checkpoint
+      save operation is currently running in the background. Policies (like
+      `ContinuousCheckpointingPolicy`) use this to avoid triggering overlapping
+      save operations.
+    reached_preemption: Indicates whether a preemption signal has been
+      received from the cluster manager, meaning the training job is about to
+      be terminated.
+    multiprocessing_options: Configuration details for distributed multihost
+      training. This provides information such as primary host identification
+      for synchronization barriers.
+  """
 
   is_saving_in_progress: bool
   reached_preemption: bool
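
To make the new `DecisionContext` docstring concrete, here is a minimal sketch of how a policy might read the context. It does not import Orbax: the `DecisionContext` below is a stand-in that only mirrors the three documented field names, the `Optional[Any]` type and `None` default on `multiprocessing_options` are placeholders, and `should_save_if_idle` is a hypothetical helper, not part of the library.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class DecisionContext:
  """Stand-in mirroring the documented fields (not the Orbax class)."""

  is_saving_in_progress: bool
  reached_preemption: bool
  # Placeholder type and default; the real field's type is not shown here.
  multiprocessing_options: Optional[Any] = None


def should_save_if_idle(context: DecisionContext) -> bool:
  """Hypothetical helper: allow a save only when no async save is running."""
  return not context.is_saving_in_progress


busy = DecisionContext(is_saving_in_progress=True, reached_preemption=False)
idle = DecisionContext(is_saving_in_progress=False, reached_preemption=False)
print(should_save_if_idle(busy), should_save_if_idle(idle))
```

The helper refuses to save while `is_saving_in_progress` is set, which is the overlap-avoidance behavior the docstring attributes to `ContinuousCheckpointingPolicy`.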
@@ -166,7 +184,8 @@ class ContinuousCheckpointingPolicy(SaveDecisionPolicy):
   """Checkpoint as often as possible, as long as a save is not ongoing.
 
   This policy evaluates to True as often as possible. It enforces two primary
-  constraints to prevent blocking training or causing other regressions.:
+  constraints to prevent blocking training or causing other regressions.
+
   1. It will never trigger a new save if a save is currently in progress
      (checked via the provided `DecisionContext`); this prevents blocking on an
      ongoing save, which would hurt accelerator utilization.
@@ -254,7 +273,31 @@ def _get_should_save_result() -> bool:
 
 
 class PreemptionCheckpointingPolicy(SaveDecisionPolicy):
-  """Save a checkpoint when a preemption is detected."""
+  """Save a checkpoint when a preemption is detected.
+
+  This policy evaluates to True strictly when the provided `DecisionContext`
+  indicates that a preemption signal has been received (i.e.,
+  `context.reached_preemption` is True). It can be useful for ensuring that
+  training progress is safely stored before the job is killed by the cluster
+  scheduler.
+
+  Note that saving on preemption is not strictly necessary, however. For
+  example, if continuous checkpointing is employed, and checkpoints are saved
+  frequently enough, the cost of re-computing some amount of steps can be
+  cheaper than the cost of waiting for a checkpoint to complete after
+  preemption.
+
+  Methods:
+    should_save(step, previous_steps, *, context):
+      Evaluates whether a preemption signal has been registered in the context.
+
+      Args:
+        step (PolicyCheckpointInfo): Ignored by this policy.
+        previous_steps (Sequence[PolicyCheckpointInfo]): Ignored by this policy.
+        context (DecisionContext): A container for auxiliary information. This
+          policy specifically checks the `reached_preemption` boolean flag to
+          make its decision.
+  """
 
   def should_save(
       self,
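
The behavior described in the new `PreemptionCheckpointingPolicy` docstring can be sketched as follows. This is an illustration under stated assumptions, not the Orbax implementation: `StepInfo` and `Context` are hypothetical stand-ins for `PolicyCheckpointInfo` and `DecisionContext`, and `PreemptionSketch` only follows the documented `should_save(step, previous_steps, *, context)` signature.

```python
import dataclasses
from typing import Sequence


@dataclasses.dataclass
class StepInfo:
  """Hypothetical stand-in for PolicyCheckpointInfo."""

  step: int


@dataclasses.dataclass
class Context:
  """Hypothetical stand-in for DecisionContext."""

  is_saving_in_progress: bool = False
  reached_preemption: bool = False


class PreemptionSketch:
  """Saves only when the context reports a preemption signal."""

  def should_save(
      self,
      step: StepInfo,
      previous_steps: Sequence[StepInfo],
      *,
      context: Context,
  ) -> bool:
    del step, previous_steps  # Ignored, as the docstring states.
    return context.reached_preemption


policy = PreemptionSketch()
normal = policy.should_save(StepInfo(5), [], context=Context())
preempted = policy.should_save(
    StepInfo(5), [StepInfo(1)], context=Context(reached_preemption=True)
)
print(normal, preempted)
```

The step arguments have no influence on the result; only the `reached_preemption` flag does.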
@@ -274,7 +317,24 @@ def should_save(
 
 
 class InitialSavePolicy(SaveDecisionPolicy):
-  """Checkpoint as soon as possible if no checkpoints already exist."""
+  """Save a checkpoint as soon as possible if no checkpoints already exist.
+
+  This policy evaluates to True only if the `previous_steps` sequence is empty.
+  It is highly useful for ensuring a baseline checkpoint is created immediately
+  upon starting a fresh training run, while safely evaluating to False if the
+  job is restarting from an existing checkpoint.
+
+  Methods:
+    should_save(step, previous_steps, *, context):
+      Evaluates whether the `previous_steps` history is empty.
+
+      Args:
+        step (PolicyCheckpointInfo): Ignored by this policy.
+        previous_steps (Sequence[PolicyCheckpointInfo]): A chronological list
+          of metadata for previously saved steps. The policy checks if this is
+          empty.
+        context (DecisionContext): Ignored by this policy.
+  """
 
   def should_save(
       self,
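
As a rough illustration of the `InitialSavePolicy` docstring, the sketch below returns True only when the history of saved steps is empty. It is not the library code: `InitialSaveSketch` is a hypothetical class, and plain integers stand in for `PolicyCheckpointInfo` entries.

```python
from typing import Any, Optional, Sequence


class InitialSaveSketch:
  """Saves only when no earlier checkpoint is recorded."""

  def should_save(
      self,
      step: int,
      previous_steps: Sequence[int],
      *,
      context: Optional[Any] = None,
  ) -> bool:
    del step, context  # Ignored, as the docstring states.
    return len(previous_steps) == 0


policy = InitialSaveSketch()
fresh_run = policy.should_save(0, [])
restarted_run = policy.should_save(300, [100, 200])
print(fresh_run, restarted_run)
```

A fresh run has an empty history and triggers the baseline save; a run restored from existing checkpoints does not.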
@@ -297,8 +357,25 @@ def should_save(
 class AnySavePolicy(SaveDecisionPolicy):
   """Evaluates all policies and saves if any of them returns True.
 
-  Each policy is evaluated in order, and if all are False, the final result is
-  False. If at least one is True, the final result is True.
+  This policy iterates through a provided sequence of child policies. It
+  evaluates each one in order and returns True immediately if any child policy
+  returns True. If all child policies return False, this policy returns False.
+  It is highly useful for combining time-based, step-based, and event-based
+  saving rules into a single, unified checkpointer configuration.
+
+  Attributes:
+    policies (Sequence[SaveDecisionPolicy]): An ordered collection of underlying
+      policies to evaluate.
+
+  Methods:
+    should_save(step, previous_steps, *, context):
+      Evaluates the sequence of configured policies.
+
+      Args:
+        step (PolicyCheckpointInfo): Passed down to each child policy.
+        previous_steps (Sequence[PolicyCheckpointInfo]): Passed down to each
+          child policy.
+        context (DecisionContext): Passed down to each child policy.
   """
 
   policies: Sequence[SaveDecisionPolicy]
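
The in-order, stop-at-first-True evaluation described in the new `AnySavePolicy` docstring can be sketched with Python's built-in `any()` over a generator. This follows the docstring rather than the actual Orbax source: `FixedPolicy` and `AnySaveSketch` are hypothetical names, and the `calls` list exists only to show which children were consulted.

```python
from typing import Any, List, Sequence


class FixedPolicy:
  """Hypothetical child policy that returns a fixed answer and logs calls."""

  def __init__(self, name: str, answer: bool, calls: List[str]):
    self.name = name
    self.answer = answer
    self.calls = calls

  def should_save(
      self, step: Any, previous_steps: Sequence[Any], *, context: Any
  ) -> bool:
    self.calls.append(self.name)
    return self.answer


class AnySaveSketch:
  """Returns True as soon as one child policy returns True."""

  def __init__(self, policies: Sequence[FixedPolicy]):
    self.policies = policies

  def should_save(
      self, step: Any, previous_steps: Sequence[Any], *, context: Any
  ) -> bool:
    # any() over a generator stops at the first True result.
    return any(
        p.should_save(step, previous_steps, context=context)
        for p in self.policies
    )


calls: List[str] = []
combined = AnySaveSketch([
    FixedPolicy("a", False, calls),
    FixedPolicy("b", True, calls),
    FixedPolicy("c", True, calls),
])
result = combined.should_save(10, [], context=None)
print(result, calls)
```

Here `result` is True and `calls` is `['a', 'b']`: the third child is never consulted because the second already returned True.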
