Skip to content

Commit 7a51a4d

Browse files
fidlejMctxDev
authored andcommitted
Use jax.tree_util to avoid deprecation warnings.
PiperOrigin-RevId: 463276750
1 parent e5d0074 commit 7a51a4d

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

mctx/_src/search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def expand(
217217
chex.assert_shape([parent_index, action, next_node_index], (batch_size,))
218218

219219
# Retrieve states for nodes to be evaluated.
220-
embedding = jax.tree_map(
220+
embedding = jax.tree_util.tree_map(
221221
lambda x: x[batch_range, parent_index], tree.embeddings)
222222

223223
# Evaluate and create a new node.
@@ -333,7 +333,7 @@ def update_tree_node(
333333
tree.node_values, value, node_index),
334334
node_visits=batch_update(
335335
tree.node_visits, new_visit, node_index),
336-
embeddings=jax.tree_map(
336+
embeddings=jax.tree_util.tree_map(
337337
lambda t, s: batch_update(t, s, node_index),
338338
tree.embeddings, embedding))
339339

@@ -373,7 +373,7 @@ def _zeros(x):
373373
children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
374374
children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
375375
children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
376-
embeddings=jax.tree_map(_zeros, root.embedding),
376+
embeddings=jax.tree_util.tree_map(_zeros, root.embedding),
377377
root_invalid_actions=root_invalid_actions,
378378
extra_data=extra_data)
379379

mctx/_src/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def infer_batch_size(tree: Tree) -> int:
117117
"""Recovers batch size from `Tree` data structure."""
118118
if tree.node_values.ndim != 2:
119119
raise ValueError("Input tree is not batched.")
120-
chex.assert_equal_shape_prefix(jax.tree_leaves(tree), 1)
120+
chex.assert_equal_shape_prefix(jax.tree_util.tree_leaves(tree), 1)
121121
return tree.node_values.shape[0]
122122

123123

0 commit comments

Comments
 (0)