Skip to content

[Jax API update] Remove jax_spmd_mode #1136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions axlearn/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

import attr
import jax
import jax.flatten_util
import numpy as np
from absl import logging
from jax import numpy as jnp
Expand All @@ -54,6 +53,7 @@
from jax.experimental import mesh_utils, multihost_utils
from jax.extend.core import Primitive
from jax.sharding import PartitionSpec
from jax.tree_util import default_registry

from axlearn.common import serialization
from axlearn.common.config import (
Expand Down Expand Up @@ -1853,36 +1853,23 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
return grouped_lines


def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
def pytree_children(node: Any) -> list[tuple[KeyEntry, Any]]:
Copy link
Contributor

@apghml apghml May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this example from the jax docs fail with the new implementation?

import jax.numpy as jnp
import jax.tree
from jax.tree_util import GetAttrKey


class MyContainer:

    def __init__(self):

        self.x = jnp.zeros(2)
        self.y = jnp.ones(2)



def flatten_with_keys(obj):

    children = [(GetAttrKey('x'), obj.x),
                (GetAttrKey('y'), obj.y)]



    aux_data = ()  # aux_data must contain static, hashable data.

    return children, aux_data

def unflatten(aux_data, children):
    obj = object.__new__(MyContainer)

    obj.x, obj.y = children

    obj.size, = aux_data

    return obj

jax.tree_util.register_pytree_node(MyContainer, flatten_with_keys, unflatten)

pytree_children(MyContainer())

"""Generate the (key, value) pairs for the immediate children of a pytree `node`.

The returned children match those returned by
`jax.tree_util.default_registry.flatten_one_level()`.

Reference: jax._src.tree_util.generate_key_paths()

Example:
```
assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]
```
"""
# pylint: disable-next=protected-access
registry_with_keypaths = jax._src.tree_util._registry_with_keypaths

key_handler = registry_with_keypaths.get(type(node))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(node)
return key_children

flat = jax.tree_util.default_registry.flatten_one_level(node)
if flat is None:
try:
key_child_pairs, _ = default_registry.flatten_one_level_with_keys(node)
return list(key_child_pairs)
except ValueError:
return []

if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node):
# Handle namedtuple as a special case, based on heuristic.
return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields]
return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])]


def find_cycles(tree: Nested) -> dict[str, KeyPath]:
"""Find a cycle in pytree `tree` if one exists.
Expand Down
2 changes: 0 additions & 2 deletions axlearn/common/utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def setup(
"""
# Use a GSPMD-friendly PRNG implementation.
jax.config.update("jax_default_prng_impl", "rbg")
# This allows replicated jax.Arrays to be used for computation on the host.
jax.config.update("jax_spmd_mode", "allow_all")

global _jax_distributed_initialized # pylint: disable=global-statement
if not _jax_distributed_initialized:
Expand Down