Skip to content

[Jax API update] Remove jax_spmd_mode #1136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions axlearn/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,26 @@
from typing import (
Any,
Callable,
List,
Literal,
NamedTuple,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
runtime_checkable,
)

import attr
import jax
import jax.flatten_util
import numpy as np
from absl import logging
from jax import numpy as jnp
from jax._src.ad_checkpoint import name_p
from jax._src.lax import lax as lax_internal
from jax._src.mesh import thread_resources
from jax._src.tree_util import KeyEntry, KeyPath
from jax._src.tree_util import KeyEntry, KeyPath, flatten_one_level_with_keys
from jax.ad_checkpoint import Offloadable, Recompute, Saveable
from jax.experimental import mesh_utils, multihost_utils
from jax.extend.core import Primitive
Expand Down Expand Up @@ -1853,7 +1854,7 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
return grouped_lines


def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
def pytree_children(node: Any) -> List[Tuple[KeyEntry, Any]]:
"""Generate the (key, value) pairs for the immediate children of a pytree `node`.

The returned children match those returned by
Expand All @@ -1866,23 +1867,13 @@ def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]
```
"""
# pylint: disable-next=protected-access
registry_with_keypaths = jax._src.tree_util._registry_with_keypaths

key_handler = registry_with_keypaths.get(type(node))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(node)
return key_children

flat = jax.tree_util.default_registry.flatten_one_level(node)
if flat is None:
try:
# pylint: disable-next=protected-access
key_child_pairs, _ = flatten_one_level_with_keys(node)
return list(key_child_pairs)
except ValueError:
return []

if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node):
# Handle namedtuple as a special case, based on heuristic.
return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields]
return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])]


def find_cycles(tree: Nested) -> dict[str, KeyPath]:
"""Find a cycle in pytree `tree` if one exists.
Expand Down
2 changes: 0 additions & 2 deletions axlearn/common/utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def setup(
"""
# Use a GSPMD-friendly PRNG implementation.
jax.config.update("jax_default_prng_impl", "rbg")
# This allows replicated jax.Arrays to be used for computation on the host.
jax.config.update("jax_spmd_mode", "allow_all")

global _jax_distributed_initialized # pylint: disable=global-statement
if not _jax_distributed_initialized:
Expand Down