Skip to content

Commit b67c440

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Internal change.
PiperOrigin-RevId: 873037681
1 parent 306242b commit b67c440

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,11 @@ class TrainState:
10351035
serialized_item, value_metadata_tree
10361036
)
10371037
else:
1038+
# Deserialize value metadata tree to the same structure as item to allow
1039+
# for comparison with item that contains rich types.
1040+
value_metadata_tree = tree_utils.deserialize_tree(
1041+
value_metadata_tree, item
1042+
)
10381043
# is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
10391044
# dicts, lists, custom nodes) as leaves, as they do not contain any
10401045
# actual data to be restored, but are needed to maintain the structure.
@@ -1060,13 +1065,7 @@ class TrainState:
10601065
restore_args = tree_metadata.serialize_tree(
10611066
restore_args, self._pytree_metadata_options
10621067
)
1063-
1064-
value_metadata_tree_deserialized = tree_utils.deserialize_tree(
1065-
value_metadata_tree, item
1066-
)
1067-
restore_args_deserialized = tree_utils.deserialize_tree(restore_args, item)
1068-
value_metadata_tree = value_metadata_tree_deserialized
1069-
restore_args = restore_args_deserialized
1068+
restore_args = tree_utils.deserialize_tree(restore_args, item)
10701069

10711070
param_infos = self._get_param_infos(
10721071
item=value_metadata_tree,

checkpoint/orbax/checkpoint/_src/tree/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar, Union
1818

19+
import flax
1920
import jax
2021
import jax.tree_util as jtu
2122
from orbax.checkpoint._src.arrays import abstract_arrays
@@ -235,10 +236,14 @@ def _reconstruct_from_keypath(keypath, _):
235236
result = serialized
236237
for key in keypath:
237238
key_name = get_key_name(key)
238-
if isinstance(key, jax.tree_util.GetAttrKey) and isinstance_of_namedtuple(
239-
result
240-
):
241-
result = getattr(result, key_name)
239+
if isinstance(key, jax.tree_util.GetAttrKey):
240+
if isinstance_of_namedtuple(result):
241+
result = getattr(result, key_name)
242+
elif isinstance(result, flax.struct.PyTreeNode):
243+
# Special case to support flax.struct.PyTreeNode
244+
result = result.__dict__[key_name]
245+
else:
246+
result = result[key_name]
242247
else:
243248
# Special case to support Pax.
244249
if not isinstance(result, (list, tuple)) and key_name not in result:

checkpoint/orbax/checkpoint/checkpoint_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ def construct_restore_args(
454454
sharding_tree: Optional[PyTree] = None,
455455
set_global_shape: bool = True,
456456
support_layout: bool = False,
457+
allow_uneven_sharding: bool = False,
457458
) -> PyTree:
458459
"""Creates restore_args given a target PyTree.
459460
@@ -496,6 +497,7 @@ def construct_restore_args(
496497
set_global_shape: If true, set the `global_shape` field of ArrayRestoreArgs.
497498
support_layout: If true, layout is extracted from jax.Array or
498499
jax.ShapeDtypeStruct.
500+
allow_uneven_sharding: If true, allow padding/slicing for uneven sharding.
499501
500502
Returns:
501503
A PyTree matching target of RestoreArgs (or ArrayRestoreArgs) objects.
@@ -517,6 +519,7 @@ def _array_restore_args(
517519
sharding=sharding,
518520
global_shape=global_shape,
519521
dtype=dtype,
522+
strict=not allow_uneven_sharding,
520523
)
521524

522525
def _restore_args(

0 commit comments

Comments
 (0)