@@ -217,7 +217,7 @@ def expand(
217217 chex .assert_shape ([parent_index , action , next_node_index ], (batch_size ,))
218218
219219 # Retrieve states for nodes to be evaluated.
220- embedding = jax .tree_map (
220+ embedding = jax .tree_util . tree_map (
221221 lambda x : x [batch_range , parent_index ], tree .embeddings )
222222
223223 # Evaluate and create a new node.
@@ -333,7 +333,7 @@ def update_tree_node(
333333 tree .node_values , value , node_index ),
334334 node_visits = batch_update (
335335 tree .node_visits , new_visit , node_index ),
336- embeddings = jax .tree_map (
336+ embeddings = jax .tree_util . tree_map (
337337 lambda t , s : batch_update (t , s , node_index ),
338338 tree .embeddings , embedding ))
339339
@@ -373,7 +373,7 @@ def _zeros(x):
373373 children_visits = jnp .zeros (batch_node_action , dtype = jnp .int32 ),
374374 children_rewards = jnp .zeros (batch_node_action , dtype = data_dtype ),
375375 children_discounts = jnp .zeros (batch_node_action , dtype = data_dtype ),
376- embeddings = jax .tree_map (_zeros , root .embedding ),
376+ embeddings = jax .tree_util . tree_map (_zeros , root .embedding ),
377377 root_invalid_actions = root_invalid_actions ,
378378 extra_data = extra_data )
379379
0 commit comments