Skip to content

Commit 306242b

Browse files
author
Orbax Authors
committed
Improve SerializationParam and DeserializationParam class docstrings
PiperOrigin-RevId: 874737798
1 parent 948fc41 commit 306242b

File tree

1 file changed

+75
-9
lines changed
  • checkpoint/orbax/checkpoint/experimental/v1/_src/serialization

1 file changed

+75
-9
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Define types for :py:class:`.LeafHandler`."""
1616

1717
import dataclasses
18-
import typing
1918
from typing import Any, Awaitable, Generic, Protocol, Sequence, Tuple, Type, TypeVar
2019

2120
import jax
@@ -96,19 +95,52 @@ def is_placeholder(value: Any) -> bool:
9695
return value is PLACEHOLDER
9796

9897

99-
@typing.final
100-
@dataclasses.dataclass(frozen=True, kw_only=True)
98+
@dataclasses.dataclass
10199
class SerializationParam(Generic[Leaf]):
100+
"""Represents a specific leaf-level parameter within a PyTree.
101+
102+
SerializationParam represents a single PyTree leaf by pairing its value
103+
(data or metadata) with its keypath (the location within the original nested
104+
structure). It serves as a container for the parameters passed to
105+
`LeafHandler`, which enables the implementation of serialization support for
106+
custom leaf objects.
107+
108+
Example Usage:
109+
SerializationParam is used when implementing custom LeafHandlers::
110+
111+
class MyCustomHandler(LeafHandler):
112+
async def serialize(
113+
self,
114+
params: Sequence[SerializationParam],
115+
context: SerializationContext
116+
):
117+
for param in params:
118+
# param.value contains the object to be serialized.
119+
data = param.value
120+
121+
# Derive the name from the keypath.
122+
leaf_name = "/".join(str(k) for k in param.keypath)
123+
124+
# Using the context to determine the save location
125+
print(f"Saving {leaf_name} to {context.parent_dir}")
126+
127+
Attributes:
128+
keypath (Tuple[Any, ...]): A tuple of keys (often JAX Key objects) that
129+
defines the path from the root of the PyTree to this specific leaf.
130+
value (Any): The data associated with the leaf. This could be a jax.Array,
131+
a numpy.ndarray, or a metadata object depending on the stage of
132+
the checkpointing process.
133+
"""
102134
keypath: tree_types.PyTreeKeyPath
103135
value: Leaf
104136

105137
@property
106138
def name(self) -> str:
139+
"""The name of the parameter derived from its keypath."""
107140
return tree_utils.param_name_from_keypath(self.keypath)
108141

109142

110-
@typing.final
111-
@dataclasses.dataclass(frozen=True, kw_only=True)
143+
@dataclasses.dataclass
112144
class SerializationContext:
113145
"""A container for the execution context passed to :py:class:`LeafHandler`.
114146
@@ -147,19 +179,53 @@ async def serialize(
147179
byte_limiter: limits.LimitInFlightBytes | None = None
148180

149181

150-
@typing.final
151-
@dataclasses.dataclass(frozen=True, kw_only=True)
182+
@dataclasses.dataclass
152183
class DeserializationParam(Generic[AbstractLeaf]):
184+
"""Represents a leaf-level entry for PyTree restoration.
185+
186+
DeserializationParam represents a single PyTree leaf during the restoration
187+
process by pairing its keypath (the location within the target structure)
188+
with an optional template value. It serves as a container for the parameters
189+
passed to `LeafHandler`, which enables the implementation of deserialization
190+
support for custom leaf objects.
191+
192+
Example Usage:
193+
DeserializationParam is utilized when implementing custom LeafHandlers::
194+
195+
class MyCustomHandler(LeafHandler):
196+
async def deserialize(
197+
self,
198+
params: Sequence[DeserializationParam],
199+
context: DeserializationContext
200+
):
201+
results = []
202+
for param in params:
203+
leaf_name = "/".join(str(k) for k in param.keypath)
204+
205+
# The param argument provides the template metadata.
206+
if param.value is not None:
207+
print(f"Restoring {leaf_name} with shape: {param.value.shape}")
208+
209+
# results.append(restored_object)
210+
return results
211+
212+
Attributes:
213+
keypath (Tuple[Any, ...]): A tuple of keys defining the path from the
214+
PyTree root to this specific leaf.
215+
value (Optional[Any]): An optional template object (such as a
216+
jax.ShapeDtypeStruct or an existing array) that provides the
217+
metadata necessary to correctly instantiate the restored leaf.
218+
"""
153219
keypath: tree_types.PyTreeKeyPath
154220
value: AbstractLeaf | Type[AbstractLeaf] | None = None
155221

156222
@property
157223
def name(self) -> str:
224+
"""The name of the parameter derived from its keypath."""
158225
return tree_utils.param_name_from_keypath(self.keypath)
159226

160227

161-
@typing.final
162-
@dataclasses.dataclass(frozen=True, kw_only=True)
228+
@dataclasses.dataclass
163229
class DeserializationContext:
164230
"""A container for the execution context passed to :py:class:`LeafHandler`.
165231

0 commit comments

Comments
 (0)