Resetting environment state in PPO training_step function
#591
nic-barbara started this conversation in General
Replies: 1 comment
Hi there,

I'm currently working with a recurrent version of the PPO training algorithm in brax. For recurrent PPO, it's best to evaluate each training step (within a training epoch) over a complete trajectory of the environment, i.e. starting from just after it was last terminated/reset up until it terminates again.

To test things out, I'm looking to reset the environment in the PPO training_step function. However, I'm a little confused as to how to do this efficiently and correctly. I've tried replacing the relevant lines in training_step with a call to reset_fn, where reset_fn = jax.jit(jax.vmap(env.reset)). Doing this means I end up with an error about mismatched array sizes (the 256 in the error comes from me using num_envs = 256 for training).

I'm using my own versions of the train.py and acting.py files, slightly modified from brax v0.10.5. The changes are minimal, so I doubt they are the cause of this issue, and I can train a policy as normal without introducing the additional reset.

Any help would be greatly appreciated, and I'm happy to provide more info as required!
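For reference, a minimal sketch of the batched-reset pattern described above. This is not brax's actual training_step; the environment name and shapes are illustrative, and it assumes the unwrapped brax Env API, in which env.reset takes a single PRNG key and returns a single State.

```python
# Minimal sketch (assumptions noted above): reset a batch of environments
# by mapping env.reset over one PRNG key per environment.
import jax
from brax import envs

num_envs = 256
env = envs.get_environment('ant')  # any registered brax environment

# vmap maps env.reset over the leading axis of its argument, so the
# keys passed in must have shape (num_envs, 2): one key per environment.
reset_fn = jax.jit(jax.vmap(env.reset))

rng = jax.random.PRNGKey(0)
key_envs = jax.random.split(rng, num_envs)  # shape (num_envs, 2)
state = reset_fn(key_envs)                  # batched State, leading dim 256
```

Note that if the env handed to training_step has already been wrapped for batching (as brax's training code does during training), its reset already expects a batch of keys, so adding another jax.vmap on top can produce this kind of size mismatch.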
Hi @nic-barbara, it looks like you need to reshape key_envs to correspond to the correct batch size (see here).
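For concreteness, a hedged sketch of the reshape the comment above is pointing at, assuming brax's PPO pmaps the training epoch across local devices: the per-environment keys then need a leading device axis so each device's already-batched environment receives a block of keys of the right size. The variable names and shapes below are illustrative, not brax's exact internals.

```python
# Hypothetical sketch: fold per-environment PRNG keys into a
# (devices, envs_per_device, 2) layout for a pmapped training epoch.
import jax
import jax.numpy as jnp

num_envs = 256
local_devices = jax.local_device_count()  # assumed to divide num_envs

rng = jax.random.PRNGKey(0)
key_envs = jax.random.split(rng, num_envs)  # shape (num_envs, 2)

# Add a leading device axis so each pmapped device gets
# num_envs // local_devices keys, matching its environment batch.
key_envs = jnp.reshape(key_envs, (local_devices, -1) + key_envs.shape[1:])
assert key_envs.shape == (local_devices, num_envs // local_devices, 2)

# Inside the per-device training step, the wrapped env's reset already
# maps over the batch axis, so one device's block is passed in directly
# (no extra jax.vmap):
# state = env.reset(key_envs_for_this_device)
```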