Skip to content

Commit a8ea1be

Browse files
committed
Change deprecated jax.tree_util.tree_map to jax.tree.map. Fix argument passed to jax.numpy.finfo call.
1 parent 9fb7339 commit a8ea1be

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

mctx/_src/policies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def _mask_invalid_actions(logits, invalid_actions):
385385

386386

387387
def _get_logits_from_probs(probs):
388-
tiny = jnp.finfo(probs).tiny
388+
tiny = jnp.finfo(probs.dtype).tiny
389389
return jnp.log(jnp.maximum(probs, tiny))
390390

391391

mctx/_src/search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def expand(
219219
chex.assert_shape([parent_index, action, next_node_index], (batch_size,))
220220

221221
# Retrieve states for nodes to be evaluated.
222-
embedding = jax.tree_util.tree_map(
222+
embedding = jax.tree.map(
223223
lambda x: x[batch_range, parent_index], tree.embeddings)
224224

225225
# Evaluate and create a new node.
@@ -335,7 +335,7 @@ def update_tree_node(
335335
tree.node_values, value, node_index),
336336
node_visits=batch_update(
337337
tree.node_visits, new_visit, node_index),
338-
embeddings=jax.tree_util.tree_map(
338+
embeddings=jax.tree.map(
339339
lambda t, s: batch_update(t, s, node_index),
340340
tree.embeddings, embedding))
341341

@@ -375,7 +375,7 @@ def _zeros(x):
375375
children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
376376
children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
377377
children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
378-
embeddings=jax.tree_util.tree_map(_zeros, root.embedding),
378+
embeddings=jax.tree.map(_zeros, root.embedding),
379379
root_invalid_actions=root_invalid_actions,
380380
extra_data=extra_data)
381381

mctx/_src/tests/policies_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_gumbel_muzero_policy(self):
245245

246246
# Testing max_depth.
247247
leaf, max_found_depth = _get_deepest_leaf(
248-
jax.tree_util.tree_map(lambda x: x[0], policy_output.search_tree),
248+
jax.tree.map(lambda x: x[0], policy_output.search_tree),
249249
policy_output.search_tree.ROOT_INDEX)
250250
self.assertEqual(max_depth, max_found_depth)
251251
self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf])

0 commit comments

Comments
 (0)