diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 79ed9b707..e323a891c 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -42,7 +42,6 @@ import attr import jax -import jax.flatten_util import numpy as np from absl import logging from jax import numpy as jnp @@ -54,6 +53,7 @@ from jax.experimental import mesh_utils, multihost_utils from jax.extend.core import Primitive from jax.sharding import PartitionSpec +from jax.tree_util import default_registry from axlearn.common import serialization from axlearn.common.config import ( @@ -1853,35 +1853,38 @@ def thread_stack_traces() -> Sequence[Sequence[str]]: return grouped_lines -def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]: +def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]: """Generate the (key, value) pairs for the immediate children of a pytree `node`. The returned children match those returned by `jax.tree_util.default_registry.flatten_one_level()`. - Reference: jax._src.tree_util.generate_key_paths() - Example: ``` assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])] ``` """ - # pylint: disable-next=protected-access - registry_with_keypaths = jax._src.tree_util._registry_with_keypaths - - key_handler = registry_with_keypaths.get(type(node)) - if key_handler: - key_children, _ = key_handler.flatten_with_keys(node) - return key_children + # If node is a NamedTuple + if isinstance(node, tuple) and hasattr(node, "_fields"): + return [(jax.tree_util.GetAttrKey(name), getattr(node, name)) for name in node._fields] + + # If node is not a NT but exposes a public `_fields` attribute + if hasattr(node, "_fields") and not isinstance(node, tuple): + return [(jax.tree_util.GetAttrKey(name), getattr(node, name)) for name in node._fields] + # Standard JAX + try: + key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node) + if key_child_pairs: + return list(key_child_pairs) + except (ValueError, TypeError): + pass + # Node is Sequence flat = jax.tree_util.default_registry.flatten_one_level(node) if flat is None: return [] - if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node): - # Handle namedtuple as a special case, based on heuristic. - return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields] - return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])] + return [(jax.tree_util.FlattenedIndexKey(i), child) for i, child in enumerate(flat[0])] def find_cycles(tree: Nested) -> dict[str, KeyPath]: diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 62330d786..1917e99a3 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -43,8 +43,6 @@ def setup( """ # Use a GSPMD-friendly PRNG implementation. jax.config.update("jax_default_prng_impl", "rbg") - # This allows replicated jax.Arrays to be used for computation on the host. - jax.config.update("jax_spmd_mode", "allow_all") global _jax_distributed_initialized # pylint: disable=global-statement if not _jax_distributed_initialized: diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index 2411759c3..0b833906f 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -334,6 +334,25 @@ class TestUnstructured: [(jax.tree_util.FlattenedIndexKey(k), v) for k, v in enumerate(original_tree.values())], ) + # eg OutputCollection(summaries={}, state_updates={}, module_outputs={})) + class CustomWithFields: + _fields = ("a", "b", "c") + + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + tree = CustomWithFields(**original_tree) + self.assertSequenceEqual( + pytree_children(tree), + [(jax.tree_util.GetAttrKey(k), getattr(tree, k)) for k in CustomWithFields._fields], + ) + + # Test object() + obj = object() + self.assertSequenceEqual(pytree_children(obj), []) + # No children self.assertSequenceEqual(pytree_children([]), [])