Skip to content

Commit 99bfb4b

Browse files
author
Orbax Authors
committed
Improve Context class docstring
PiperOrigin-RevId: 877210490
1 parent b9de9ab commit 99bfb4b

File tree

1 file changed

+43
-21
lines changed
  • checkpoint/orbax/checkpoint/experimental/v1/_src/context

1 file changed

+43
-21
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,51 @@ def get_context(default: Context | None = None) -> Context:
4444
class Context(epy.ContextManager):
4545
"""Context for customized checkpointing.
4646
47-
Usage example::
48-
49-
with ocp.Context(...):
50-
ocp.save_pytree(...)
47+
This class manages the configuration options (e.g., async, multiprocessing,
48+
array handling) used during Orbax checkpoint operations.
5149
5250
Creating a new :py:class:`.Context` within an existing :py:class:`.Context`
53-
sets all parameters from scratch; it does not inherit properties from the
54-
parent :py:class:`.Context`. To achieve this, use::
55-
56-
with Context(**some_properties) as outer_ctx:
57-
with Context(outer_ctx, **other) as inner_ctx:
58-
...
59-
60-
The `inner_ctx` will have the same properties as `outer_ctx`, except for any
61-
properties modified in the `dataclasses.replace` call.
62-
63-
NOTE: The context is not shared across threads. In other words, the whole
64-
context block must be executed in the same thread. The following example will
65-
not work as expected::
66-
67-
executor = ThreadPoolExecutor()
68-
with ocp.Context(...): # Thread #1 creates Context A.
69-
executor.submit(ocp.save_pytree, ...) # Thread #2 sees "default" Context.
51+
sets all parameters from scratch by default. To inherit properties from a
52+
parent :py:class:`.Context`, you must explicitly pass the parent context as
53+
the first argument. The new context will inherit the parent's properties,
54+
except for any options explicitly provided as keyword arguments to the child
55+
context.
56+
57+
WARNING: The context is thread-local and is not shared across threads. The
58+
entire context block must be executed within the same thread. If you dispatch
59+
a checkpointing operation to a worker thread (e.g., via `ThreadPoolExecutor`),
60+
that thread will not inherit the context and will fall back to default
61+
settings.
62+
63+
Example:
64+
Basic usage and explicit inheritance::
65+
66+
import orbax.checkpoint as ocp
67+
68+
# Basic usage
69+
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()):
70+
ocp.save_pytree(directory, tree)
71+
72+
# Inheriting properties from an existing context
73+
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx:
74+
# inner_ctx inherits pytree_options, but overrides/adds array_options
75+
with ocp.Context(outer_ctx,
76+
array_options=ocp.options.ArrayOptions()
77+
) as inner_ctx:
78+
ocp.save_pytree(directory, tree)
79+
80+
Context is not shared across threads::
81+
82+
from concurrent.futures import ThreadPoolExecutor
83+
import orbax.checkpoint as ocp
84+
85+
executor = ThreadPoolExecutor(max_workers=1)
86+
with ocp.Context(
87+
pytree_options=ocp.options.PyTreeOptions()
88+
): # Thread #1 creates Context.
89+
# The following save_pytree call is executed in Thread #2, which sees
90+
# a "default" Context, NOT the one created above.
91+
executor.submit(ocp.save_pytree, directory, tree)
7092
7193
7294
Attributes:

0 commit comments

Comments
 (0)