Skip to content

Commit f8acd28

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 633689990
1 parent 5e27e61 commit f8acd28

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

language_table/train/policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, time_step_spec, action_spec, model, checkpoint_path,
6767

6868
def _run_action_inference(self, observation):
6969
# Add a batch dim.
70-
observation = jax.tree_map(lambda x: jnp.expand_dims(x, 0), observation)
70+
observation = jax.tree.map(lambda x: jnp.expand_dims(x, 0), observation)
7171

7272
normalized_action = self.model.apply(
7373
self.variables, observation, train=False)

language_table/train/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _body_fun(step, state_and_metrics):
4444
train_rng = jax.random.fold_in(train_rng, jax.lax.axis_index("batch"))
4545
new_state, metrics_update = agent.train(
4646
state=state,
47-
batch=jax.tree_map(lambda x: x[step], batches),
47+
batch=jax.tree.map(lambda x: x[step], batches),
4848
rng=train_rng,
4949
)
5050
return new_state, metrics_update
@@ -101,7 +101,7 @@ def train(
101101
)
102102
train_iter = train_ds.as_numpy_iterator()
103103
rng, agent_rng = jax.random.split(rng)
104-
sample_batch = jax.tree_map(lambda x: x[0][0], next(train_iter))
104+
sample_batch = jax.tree.map(lambda x: x[0][0], next(train_iter))
105105
agent = create_agent(
106106
config.agent_name,
107107
config.model_name,
@@ -257,7 +257,7 @@ def create_agent(
257257

258258
def merge_batch_stats(replicated_state):
259259
"""Merge model batch stats."""
260-
if jax.tree_leaves(replicated_state.batch_stats):
260+
if jax.tree.leaves(replicated_state.batch_stats):
261261
cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), "x")
262262
return replicated_state.replace(
263263
batch_stats=cross_replica_mean(replicated_state.batch_stats)

0 commit comments

Comments
 (0)