Skip to content

Commit c226026

Browse files
cpgaffney1Orbax Authors
authored andcommitted
#v1 Shorten naming for publicly-exposed save_decision_policies and preservation_policies.
PiperOrigin-RevId: 878158437
1 parent 7fc5266 commit c226026

File tree

8 files changed

+55
-58
lines changed

8 files changed

+55
-58
lines changed

checkpoint/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
- #v1 Deleted `CompositeHandler` and refactored its functionality directly into
1818
`OrbaxLayout` internal handler resolution logic.
1919

20+
### Changed
21+
22+
- #v1 Shorten naming for publicly-exposed save_decision_policies and preservation_policies.
23+
2024
## [0.11.33] - 2025-02-17
2125

2226
### Added

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,17 @@ def __init__(
9696
of checkpoints are saved at regular intervals. Example usage::
9797
9898
# Configure the frequency at which checkpoints are saved.
99-
save_decision_policies = ocp.training.save_decision_policies
100-
# Save every 1000 steps, or when a preemption is detected.
101-
save_decision_policy = save_decision_policies.AnySavePolicy([
102-
save_decision_policies.FixedIntervalPolicy(1000),
103-
save_decision_policies.PreemptionPolicy(),
99+
save_decision = ocp.training.save_decision
100+
save_decision_policy = save_decision.AnySavePolicy([
101+
save_decision.FixedIntervalPolicy(1000),
102+
save_decision.PreemptionPolicy(),
104103
])
105104
106105
# Configure the checkpoints to preserve (avoid garbage collection).
107-
preservation_policies = ocp.training.preservation_policies
108-
# Avoid garbage collection on the latest 10, or every 10000 steps.
109-
preservation_policy = preservation_policies.AnyPreservationPolicy([
110-
preservation_policies.LatestN(10),
111-
preservation_policies.EveryNSteps(10000),
106+
preservation = ocp.training.preservation
107+
preservation_policy = preservation.AnyPreservationPolicy([
108+
preservation.LatestN(10),
109+
preservation.EveryNSteps(10000),
112110
])
113111
114112
with ocp.training.Checkpointer(
@@ -150,6 +148,7 @@ def __init__(
150148
metadata. This should be information that is relevant to the entire
151149
sequence of checkpoints, rather than to any single checkpoint.
152150
"""
151+
153152
context = context_lib.get_context()
154153

155154
default_save_decision_policy = save_decision_policies.AnySavePolicy([

checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
3535

3636
Checkpointer = ocp.training.Checkpointer
37-
save_decision_policies = ocp.training.save_decision_policies
38-
preservation_policies = ocp.training.preservation_policies
37+
save_decision = ocp.training.save_decision
38+
preservation = ocp.training.preservation
3939

4040
RootMetadata = ocp.training.RootMetadata
4141
CheckpointMetadata = ocp.training.CheckpointMetadata
@@ -177,20 +177,20 @@ def test_load_non_existent_step(self):
177177
checkpointer.pytree_metadata(1)
178178

179179
@parameterized.parameters(
180-
(save_decision_policies.ContinuousCheckpointingPolicy(), range(10)),
181-
(save_decision_policies.FixedIntervalPolicy(1), range(10)),
182-
(save_decision_policies.FixedIntervalPolicy(2), range(0, 10, 2)),
180+
(save_decision.ContinuousCheckpointingPolicy(), range(10)),
181+
(save_decision.FixedIntervalPolicy(1), range(10)),
182+
(save_decision.FixedIntervalPolicy(2), range(0, 10, 2)),
183183
(
184-
save_decision_policies.AnySavePolicy([
185-
save_decision_policies.SpecificStepsPolicy((2, 3, 6)),
186-
save_decision_policies.InitialSavePolicy(),
184+
save_decision.AnySavePolicy([
185+
save_decision.SpecificStepsPolicy((2, 3, 6)),
186+
save_decision.InitialSavePolicy(),
187187
]),
188188
[0, 2, 3, 6],
189189
),
190190
)
191191
def test_steps(
192192
self,
193-
policy: save_decision_policies.SaveDecisionPolicy,
193+
policy: save_decision.SaveDecisionPolicy,
194194
expected_steps: Sequence[int],
195195
):
196196
num_steps = 10
@@ -211,7 +211,7 @@ def test_steps(
211211
def test_force_save_ignores_save_decision_policy(self):
212212
checkpointer = Checkpointer(
213213
self.directory,
214-
save_decision_policy=save_decision_policies.FixedIntervalPolicy(2),
214+
save_decision_policy=save_decision.FixedIntervalPolicy(2),
215215
)
216216
self.enter_context(checkpointer)
217217

@@ -522,14 +522,14 @@ def test_different_custom_checkpointables(self):
522522
def test_custom_save_decision_policy(self):
523523
save_delta = datetime.timedelta(seconds=0.03)
524524

525-
class ArbitrarySavePolicy(save_decision_policies.SaveDecisionPolicy):
525+
class ArbitrarySavePolicy(save_decision.SaveDecisionPolicy):
526526

527527
def should_save(
528528
self,
529529
step: CheckpointMetadata,
530530
previous_steps: Sequence[CheckpointMetadata],
531531
*,
532-
context: save_decision_policies.DecisionContext,
532+
context: save_decision.DecisionContext,
533533
) -> bool:
534534
save_result = False
535535
is_primary_host = multihost.is_primary_host(
@@ -569,14 +569,14 @@ def should_save(
569569

570570
@parameterized.parameters(
571571
(None, range(10)),
572-
(ocp.training.preservation_policies.PreserveAll(), range(10)),
573-
(preservation_policies.LatestN(3), range(7, 10)),
574-
(preservation_policies.EveryNSteps(4), [0, 4, 8]),
575-
(preservation_policies.EveryNSeconds(40), range(0, 10, 2)),
572+
(ocp.training.preservation.PreserveAll(), range(10)),
573+
(preservation.LatestN(3), range(7, 10)),
574+
(preservation.EveryNSteps(4), [0, 4, 8]),
575+
(preservation.EveryNSeconds(40), range(0, 10, 2)),
576576
(
577-
preservation_policies.AnyPreservationPolicy([
578-
preservation_policies.LatestN(3),
579-
preservation_policies.CustomSteps([1, 3, 9]),
577+
preservation.AnyPreservationPolicy([
578+
preservation.LatestN(3),
579+
preservation.CustomSteps([1, 3, 9]),
580580
]),
581581
[1, 3, 7, 8, 9],
582582
),
@@ -619,27 +619,21 @@ def now(cls, tz=None):
619619

620620
@parameterized.parameters(
621621
(
622-
preservation_policies.BestN(
623-
get_metric_fn=lambda m: m['accuracy'], n=2
624-
),
622+
preservation.BestN(get_metric_fn=lambda m: m['accuracy'], n=2),
625623
[0, 4],
626624
),
627625
(
628-
preservation_policies.BestN(
626+
preservation.BestN(
629627
get_metric_fn=lambda m: m['loss'], reverse=True, n=2
630628
),
631629
[0, 1],
632630
),
633631
(
634-
preservation_policies.BestN(
635-
get_metric_fn=lambda m: m['accuracy'], n=None
636-
),
632+
preservation.BestN(get_metric_fn=lambda m: m['accuracy'], n=None),
637633
range(5),
638634
),
639635
(
640-
preservation_policies.BestN(
641-
get_metric_fn=lambda m: m['accuracy'], n=0
642-
),
636+
preservation.BestN(get_metric_fn=lambda m: m['accuracy'], n=0),
643637
[],
644638
),
645639
)

checkpoint/orbax/checkpoint/experimental/v1/_src/training/save_decision_policies.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class SaveDecisionPolicy(Protocol):
6565
save frequency. For example::
6666
6767
from orbax.checkpoint.experimental.v1 import training
68-
policies = training.save_decision_policies
68+
policies = training.save_decision
6969
7070
# Save every 1000 steps, or when a preemption is detected.
7171
policy = policies.AnySavePolicy([
@@ -102,9 +102,9 @@ def should_save(
102102
containing the step index, timestamp, and metadata.
103103
previous_steps (Sequence[CheckpointMetadata]): A chronological list of
104104
metadata for all steps where a checkpoint was successfully saved.
105-
context (DecisionContext): A container for auxiliary information,
106-
such as validation loss or performance metrics, used to inform the
107-
save decision.
105+
context (DecisionContext): A container for auxiliary information, such
106+
as validation loss or performance metrics, used to inform the save
107+
decision.
108108
"""
109109

110110
def should_save(

checkpoint/orbax/checkpoint/experimental/v1/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
Checkpointer,
2121
)
2222

23-
from orbax.checkpoint.experimental.v1._src.training import save_decision_policies
24-
from orbax.checkpoint.experimental.v1._src.training import preservation_policies
23+
from orbax.checkpoint.experimental.v1._src.training import save_decision_policies as save_decision
24+
from orbax.checkpoint.experimental.v1._src.training import preservation_policies as preservation
2525
from orbax.checkpoint.experimental.v1._src.training import errors
2626

2727
from orbax.checkpoint.experimental.v1._src.training.metadata.types import (

docs/api_reference/checkpoint.v1.training.save_decision_policies.rst renamed to docs/api_reference/checkpoint.v1.training.save_decision.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
Training Save Decision Policies
22
============================================================================
33

4-
.. currentmodule:: orbax.checkpoint.experimental.v1.training.save_decision_policies
4+
.. currentmodule:: orbax.checkpoint.experimental.v1.training.save_decision
55

6-
.. automodule:: orbax.checkpoint.experimental.v1.training.save_decision_policies
6+
.. automodule:: orbax.checkpoint.experimental.v1.training.save_decision
77

88

99
SaveDecisionPolicy

docs/guides/checkpoint/v1/checkpointing_and_exporting_jax_models.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@
417417
"\n",
418418
"with ocp.training.Checkpointer(\n",
419419
" directory=str(training_ckpt_dir),\n",
420-
" save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(SAVE_INTERVAL_STEPS)\n",
420+
" save_decision_policy=ocp.training.save_decision.FixedIntervalPolicy(SAVE_INTERVAL_STEPS)\n",
421421
") as ckptr:\n",
422422
" for _ in range(num_training_steps):\n",
423423
" step_to_save_at = current_loop_state['step']\n",

docs/guides/checkpoint/v1/training.ipynb

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@
289289
},
290290
{
291291
"cell_type": "code",
292-
"execution_count": 8,
292+
"execution_count": null,
293293
"metadata": {
294294
"id": "tGzopEmVfXfr"
295295
},
@@ -299,7 +299,7 @@
299299
"root_directory.rmtree(missing_ok=True)\n",
300300
"with training.Checkpointer(\n",
301301
" root_directory,\n",
302-
" save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),\n",
302+
" save_decision_policy=training.save_decision.FixedIntervalPolicy(3),\n",
303303
") as ckptr:\n",
304304
" for step in range(10):\n",
305305
" ckptr.save_pytree(step, pytree)"
@@ -341,7 +341,7 @@
341341
},
342342
{
343343
"cell_type": "code",
344-
"execution_count": 10,
344+
"execution_count": null,
345345
"metadata": {
346346
"id": "YKyO0Ak1fXfr"
347347
},
@@ -351,7 +351,7 @@
351351
"root_directory.rmtree(missing_ok=True)\n",
352352
"with training.Checkpointer(\n",
353353
" root_directory,\n",
354-
" save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),\n",
354+
" save_decision_policy=training.save_decision.FixedIntervalPolicy(3),\n",
355355
" custom_metadata={'experiment_name': 'my-experiment'},\n",
356356
") as ckptr:\n",
357357
" num_steps = 10\n",
@@ -638,17 +638,17 @@
638638
},
639639
{
640640
"cell_type": "code",
641-
"execution_count": 20,
641+
"execution_count": null,
642642
"metadata": {
643643
"id": "ovozUMzx1zd7"
644644
},
645645
"outputs": [],
646646
"source": [
647647
"with training.Checkpointer(\n",
648648
" root_directory,\n",
649-
" preservation_policy=training.preservation_policies.AnyPreservationPolicy([\n",
650-
" training.preservation_policies.LatestN(2),\n",
651-
" training.preservation_policies.EveryNSteps(4),\n",
649+
" preservation_policy=training.preservation.AnyPreservationPolicy([\n",
650+
" training.preservation.LatestN(2),\n",
651+
" training.preservation.EveryNSteps(4),\n",
652652
" ]),\n",
653653
") as ckptr:\n",
654654
" print([c.step for c in ckptr.checkpoints])\n",
@@ -1094,7 +1094,7 @@
10941094
},
10951095
{
10961096
"cell_type": "code",
1097-
"execution_count": 32,
1097+
"execution_count": null,
10981098
"metadata": {
10991099
"id": "BKdtZC2nx0GF"
11001100
},
@@ -1139,7 +1139,7 @@
11391139
" # Otherwise, load the latest checkpoint; if none exists, start from scratch.\n",
11401140
" with training.Checkpointer(\n",
11411141
" root_directory,\n",
1142-
" save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(\n",
1142+
" save_decision_policy=training.save_decision.FixedIntervalPolicy(\n",
11431143
" save_interval\n",
11441144
" ),\n",
11451145
" ) as ckptr:\n",

0 commit comments

Comments
 (0)