@@ -219,7 +219,7 @@ def expand(
219219 chex .assert_shape ([parent_index , action , next_node_index ], (batch_size ,))
220220
221221 # Retrieve states for nodes to be evaluated.
222- embedding = jax .tree_util . tree_map (
222+ embedding = jax .tree . map (
223223 lambda x : x [batch_range , parent_index ], tree .embeddings )
224224
225225 # Evaluate and create a new node.
@@ -335,7 +335,7 @@ def update_tree_node(
335335 tree .node_values , value , node_index ),
336336 node_visits = batch_update (
337337 tree .node_visits , new_visit , node_index ),
338- embeddings = jax .tree_util . tree_map (
338+ embeddings = jax .tree . map (
339339 lambda t , s : batch_update (t , s , node_index ),
340340 tree .embeddings , embedding ))
341341
@@ -375,7 +375,7 @@ def _zeros(x):
375375 children_visits = jnp .zeros (batch_node_action , dtype = jnp .int32 ),
376376 children_rewards = jnp .zeros (batch_node_action , dtype = data_dtype ),
377377 children_discounts = jnp .zeros (batch_node_action , dtype = data_dtype ),
378- embeddings = jax .tree_util . tree_map (_zeros , root .embedding ),
378+ embeddings = jax .tree . map (_zeros , root .embedding ),
379379 root_invalid_actions = root_invalid_actions ,
380380 extra_data = extra_data )
381381
0 commit comments