|
15 | 15 | """Define types for :py:class:`.LeafHandler`.""" |
16 | 16 |
|
17 | 17 | import dataclasses |
18 | | -import typing |
19 | 18 | from typing import Any, Awaitable, Generic, Protocol, Sequence, Tuple, Type, TypeVar |
20 | 19 |
|
21 | 20 | import jax |
@@ -96,19 +95,52 @@ def is_placeholder(value: Any) -> bool: |
96 | 95 | return value is PLACEHOLDER |
97 | 96 |
|
98 | 97 |
|
99 | | -@typing.final |
100 | | -@dataclasses.dataclass(frozen=True, kw_only=True) |
| 98 | +@dataclasses.dataclass |
101 | 99 | class SerializationParam(Generic[Leaf]): |
| 100 | + """Represents a specific leaf-level parameter within a PyTree. |
| 101 | +
|
| 102 | + SerializationParam represents a single PyTree leaf by pairing its value |
| 103 | + (data or metadata) with its keypath (the location within the original nested |
| 104 | + structure). It serves as a container for the parameters passed to |
| 105 | + `LeafHandler`, which enables the implementation of serialization support for |
| 106 | + custom leaf objects. |
| 107 | +
|
| 108 | + Example Usage: |
| 109 | + SerializationParam is used when implementing custom LeafHandlers:: |
| 110 | +
|
| 111 | + class MyCustomHandler(LeafHandler): |
| 112 | + async def serialize( |
| 113 | + self, |
| 114 | + params: Sequence[SerializationParam], |
| 115 | + context: SerializationContext |
| 116 | + ): |
| 117 | + for param in params: |
| 118 | + # param.value contains the object to be serialized. |
| 119 | + data = param.value |
| 120 | +
|
| 121 | + # Derive the name from the keypath. |
| 122 | + leaf_name = "/".join(str(k) for k in param.keypath) |
| 123 | +
|
| 124 | + # Using the context to determine the save location |
| 125 | + print(f"Saving {leaf_name} to {context.parent_dir}") |
| 126 | +
|
| 127 | + Attributes: |
| 128 | + keypath (Tuple[Any, ...]): A tuple of keys (often JAX Key objects) that |
| 129 | + defines the path from the root of the PyTree to this specific leaf. |
| 130 | + value (Any): The data associated with the leaf. This could be a jax.Array, |
| 131 | + a numpy.ndarray, or a metadata object depending on the stage of |
| 132 | + the checkpointing process. |
| 133 | + """ |
102 | 134 | keypath: tree_types.PyTreeKeyPath |
103 | 135 | value: Leaf |
104 | 136 |
|
105 | 137 | @property |
106 | 138 | def name(self) -> str: |
| 139 | + """The name of the parameter derived from its keypath.""" |
107 | 140 | return tree_utils.param_name_from_keypath(self.keypath) |
108 | 141 |
|
109 | 142 |
|
110 | | -@typing.final |
111 | | -@dataclasses.dataclass(frozen=True, kw_only=True) |
| 143 | +@dataclasses.dataclass |
112 | 144 | class SerializationContext: |
113 | 145 | """A container for the execution context passed to :py:class:`LeafHandler`. |
114 | 146 |
|
@@ -147,19 +179,53 @@ async def serialize( |
147 | 179 | byte_limiter: limits.LimitInFlightBytes | None = None |
148 | 180 |
|
149 | 181 |
|
150 | | -@typing.final |
151 | | -@dataclasses.dataclass(frozen=True, kw_only=True) |
| 182 | +@dataclasses.dataclass |
152 | 183 | class DeserializationParam(Generic[AbstractLeaf]): |
| 184 | + """Represents a leaf-level entry for PyTree restoration. |
| 185 | +
|
| 186 | + DeserializationParam represents a single PyTree leaf during the restoration |
| 187 | + process by pairing its keypath (the location within the target structure) |
| 188 | + with an optional template value. It serves as a container for the parameters |
| 189 | + passed to `LeafHandler`, which enables the implementation of deserialization |
| 190 | + support for custom leaf objects. |
| 191 | +
|
| 192 | + Example Usage: |
| 193 | + DeserializationParam is utilized when implementing custom LeafHandlers:: |
| 194 | +
|
| 195 | + class MyCustomHandler(LeafHandler): |
| 196 | + async def deserialize( |
| 197 | + self, |
| 198 | + params: Sequence[DeserializationParam], |
| 199 | + context: DeserializationContext |
| 200 | + ): |
| 201 | + results = [] |
| 202 | + for param in params: |
| 203 | + leaf_name = "/".join(str(k) for k in param.keypath) |
| 204 | +
|
| 205 | + # The param argument provides the template metadata. |
| 206 | + if param.value is not None: |
| 207 | + print(f"Restoring {leaf_name} with shape: {param.value.shape}") |
| 208 | +
|
| 209 | + # results.append(restored_object) |
| 210 | + return results |
| 211 | +
|
| 212 | + Attributes: |
| 213 | + keypath (Tuple[Any, ...]): A tuple of keys defining the path from the |
| 214 | + PyTree root to this specific leaf. |
| 215 | + value (Optional[Any]): An optional template object (such as a |
| 216 | + jax.ShapeDtypeStruct or an existing array) that provides the |
| 217 | + metadata necessary to correctly instantiate the restored leaf. |
| 218 | + """ |
153 | 219 | keypath: tree_types.PyTreeKeyPath |
154 | 220 | value: AbstractLeaf | Type[AbstractLeaf] | None = None |
155 | 221 |
|
156 | 222 | @property |
157 | 223 | def name(self) -> str: |
| 224 | + """The name of the parameter derived from its keypath.""" |
158 | 225 | return tree_utils.param_name_from_keypath(self.keypath) |
159 | 226 |
|
160 | 227 |
|
161 | | -@typing.final |
162 | | -@dataclasses.dataclass(frozen=True, kw_only=True) |
| 228 | +@dataclasses.dataclass |
163 | 229 | class DeserializationContext: |
164 | 230 | """A container for the execution context passed to :py:class:`LeafHandler`. |
165 | 231 |
|
|
0 commit comments