Skip to content

Updated jax.tree_map to jax.tree_util.tree_map in update_transformation.py #1207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions axlearn/common/base_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from unittest import mock

import jax.ad_checkpoint
import jax.core
import jax.interpreters.ad
import jax.extend.core
import jax.extend.ad
import jax.random
import numpy as np
from absl.testing import absltest, parameterized
Expand Down Expand Up @@ -130,7 +130,7 @@ def backward_impl(x):
prim = jax.extend.core.Primitive("passthrough_with_callback")
prim.def_impl(forward_impl)
prim.def_abstract_eval(forward_impl)
jax.interpreters.ad.deflinear(prim, backward_impl)
jax.extend.ad.deflinear(prim, backward_impl)
return prim.bind


Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/flash_attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.interpreters.pxla import thread_resources
from jax.extend.pxla import thread_resources
from jax.sharding import PartitionSpec

from axlearn.common.attention import Dropout, ForwardMode, GroupedQueryAttention
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/update_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def real_transform(_):
return new_updates.delta_updates, new_state

def stop_transform(_):
return jax.tree_map(jnp.zeros_like, updates.delta_updates), prev_state
return jax.tree_util.tree_map(jnp.zeros_like, updates.delta_updates), prev_state

# We do the computation regardless of the should_update value, so we could have
# equally used jnp.where() here instead.
Expand Down
Loading