@@ -238,8 +238,6 @@ def stochastic_muzero_policy(
238238 decision_recurrent_fn : base .DecisionRecurrentFn ,
239239 chance_recurrent_fn : base .ChanceRecurrentFn ,
240240 num_simulations : int ,
241- num_actions : int ,
242- num_chance_outcomes : int ,
243241 invalid_actions : Optional [chex .Array ] = None ,
244242 max_depth : Optional [int ] = None ,
245243 loop_fn : base .LoopFn = jax .lax .fori_loop ,
@@ -271,8 +269,6 @@ def stochastic_muzero_policy(
271269 `(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
272270 `(ChanceRecurrentFnOutput, state_embedding)`.
273271 num_simulations: the number of simulations.
274- num_actions: number of environment actions.
275- num_chance_outcomes: number of chance outcomes following an afterstate.
276272 invalid_actions: a mask with invalid actions. Invalid actions have ones,
277273 valid actions have zeros in the mask. Shape `[B, num_actions]`.
278274 max_depth: maximum search tree depth allowed during simulation.
@@ -293,6 +289,8 @@ def stochastic_muzero_policy(
293289 search tree.
294290 """
295291
292+ num_actions = root .prior_logits .shape [- 1 ]
293+
296294 rng_key , dirichlet_rng_key , search_rng_key = jax .random .split (rng_key , 3 )
297295
298296 # Adding Dirichlet noise.
@@ -309,9 +307,9 @@ def stochastic_muzero_policy(
309307 # construct a dummy afterstate embedding
310308 batch_size = jax .tree_util .tree_leaves (root .embedding )[0 ].shape [0 ]
311309 dummy_action = jnp .zeros ([batch_size ], dtype = jnp .int32 )
312- _ , dummy_afterstate_embedding = decision_recurrent_fn (params , rng_key ,
313- dummy_action ,
314- root . embedding )
310+ dummy_output , dummy_afterstate_embedding = decision_recurrent_fn (
311+ params , rng_key , dummy_action , root . embedding )
312+ num_chance_outcomes = dummy_output . chance_logits . shape [ - 1 ]
315313
316314 root = root .replace (
317315 # pad action logits with num_chance_outcomes so dim is A + C
0 commit comments