Skip to content

Commit bc89134

Browse files
Fix meta ppo rng split (#43)
1 parent a46e78c commit bc89134

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

training/train_meta_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ def _meta_step(meta_state, _):
159159

160160
# INIT ENV
161161
rng, _rng1, _rng2 = jax.random.split(rng, num=3)
162-
ruleset_rng = jax.random.split(rng, num=config.num_envs_per_device)
163-
reset_rng = jax.random.split(rng, num=config.num_envs_per_device)
162+
ruleset_rng = jax.random.split(_rng1, num=config.num_envs_per_device)
163+
reset_rng = jax.random.split(_rng2, num=config.num_envs_per_device)
164164

165165
# sample rulesets for this meta update
166166
rulesets = jax.vmap(benchmark.sample_ruleset)(ruleset_rng)

0 commit comments

Comments
 (0)