Skip to content

Commit 7cb92c0

Browse files
Author: MctxDev (committed)
Merge pull request #72 from carlosgmartin:stochastic_muzero_arguments
PiperOrigin-RevId: 571906463
2 parents 545b8ee + b2e09ad commit 7cb92c0

File tree

2 files changed: 5 additions (+5) and 9 deletions (-9)

2 files changed

+5
-9
lines changed

mctx/_src/policies.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -238,8 +238,6 @@ def stochastic_muzero_policy(
238238
decision_recurrent_fn: base.DecisionRecurrentFn,
239239
chance_recurrent_fn: base.ChanceRecurrentFn,
240240
num_simulations: int,
241-
num_actions: int,
242-
num_chance_outcomes: int,
243241
invalid_actions: Optional[chex.Array] = None,
244242
max_depth: Optional[int] = None,
245243
loop_fn: base.LoopFn = jax.lax.fori_loop,
@@ -271,8 +269,6 @@ def stochastic_muzero_policy(
271269
`(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
272270
`(ChanceRecurrentFnOutput, state_embedding)`.
273271
num_simulations: the number of simulations.
274-
num_actions: number of environment actions.
275-
num_chance_outcomes: number of chance outcomes following an afterstate.
276272
invalid_actions: a mask with invalid actions. Invalid actions have ones,
277273
valid actions have zeros in the mask. Shape `[B, num_actions]`.
278274
max_depth: maximum search tree depth allowed during simulation.
@@ -293,6 +289,8 @@ def stochastic_muzero_policy(
293289
search tree.
294290
"""
295291

292+
num_actions = root.prior_logits.shape[-1]
293+
296294
rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)
297295

298296
# Adding Dirichlet noise.
@@ -309,9 +307,9 @@ def stochastic_muzero_policy(
309307
# construct a dummy afterstate embedding
310308
batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]
311309
dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
312-
_, dummy_afterstate_embedding = decision_recurrent_fn(params, rng_key,
313-
dummy_action,
314-
root.embedding)
310+
dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
311+
params, rng_key, dummy_action, root.embedding)
312+
num_chance_outcomes = dummy_output.chance_logits.shape[-1]
315313

316314
root = root.replace(
317315
# pad action logits with num_chance_outcomes so dim is A + C

mctx/_src/tests/policies_test.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -347,8 +347,6 @@ def test_stochastic_muzero_policy(self):
347347
decision_recurrent_fn=decision_rec_fn,
348348
chance_recurrent_fn=chance_rec_fn,
349349
num_simulations=2 * num_simulations,
350-
num_actions=4,
351-
num_chance_outcomes=num_chance_outcomes,
352350
invalid_actions=invalid_actions,
353351
dirichlet_fraction=0.0)
354352

0 commit comments

Comments (0)