@@ -44,29 +44,51 @@ def get_context(default: Context | None = None) -> Context:
4444class Context (epy .ContextManager ):
4545 """Context for customized checkpointing.
4646
47- Usage example::
48-
49- with ocp.Context(...):
50- ocp.save_pytree(...)
47+ This class manages the configuration options (e.g., async, multiprocessing,
48+ array handling) used during Orbax checkpoint operations.
5149
5250 Creating a new :py:class:`.Context` within an existing :py:class:`.Context`
53- sets all parameters from scratch; it does not inherit properties from the
54- parent :py:class:`.Context`. To achieve this, use::
55-
56- with Context(**some_properties) as outer_ctx:
57- with Context(outer_ctx, **other) as inner_ctx:
58- ...
59-
60- The `inner_ctx` will have the same properties as `outer_ctx`, except for any
61- properties modified in the `dataclasses.replace` call.
62-
63- NOTE: The context is not shared across threads. In other words, the whole
64- context block must be executed in the same thread. The following example will
65- not work as expected::
66-
67- executor = ThreadPoolExecutor()
68- with ocp.Context(...): # Thread #1 creates Context A.
69- executor.submit(ocp.save_pytree, ...) # Thread #2 sees "default" Context.
51+ sets all parameters from scratch by default. To inherit properties from a
52+ parent :py:class:`.Context`, you must explicitly pass the parent context as
53+ the first argument. The new context will inherit the parent's properties,
54+ except for any options explicitly provided as keyword arguments to the child
55+ context.
56+
57+ WARNING: The context is thread-local and is not shared across threads. The
58+ entire context block must be executed within the same thread. If you dispatch
59+ a checkpointing operation to a worker thread (e.g., via `ThreadPoolExecutor`),
60+ that thread will not inherit the context and will fall back to default
61+ settings.
62+
63+ Example:
64+ Basic usage and explicit inheritance::
65+
66+ import orbax.checkpoint as ocp
67+
68+ # Basic usage
69+ with ocp.Context(pytree_options=ocp.options.PyTreeOptions()):
70+ ocp.save_pytree(directory, tree)
71+
72+ # Inheriting properties from an existing context
73+ with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx:
74+ # inner_ctx inherits pytree_options, but overrides/adds array_options
75+ with ocp.Context(outer_ctx,
76+ array_options=ocp.options.ArrayOptions()
77+ ) as inner_ctx:
78+ ocp.save_pytree(directory, tree)
79+
80+ Context is not shared across threads::
81+
82+ from concurrent.futures import ThreadPoolExecutor
83+ import orbax.checkpoint as ocp
84+
85+ executor = ThreadPoolExecutor(max_workers=1)
86+ with ocp.Context(
87+ pytree_options=ocp.options.PyTreeOptions()
88+ ): # Thread #1 creates Context.
89+ # The following save_pytree call is executed in Thread #2, which sees
90+ # a "default" Context, NOT the one created above.
91+ executor.submit(ocp.save_pytree, directory, tree)
7092
7193
7294 Attributes:
0 commit comments