Resetting environment state in PPO training_step function #591
nic-barbara started this conversation in General
Replies: 1 comment
-
Hi @nic-barbara, it looks like you need to reshape `key_envs` to correspond to the correct batch size (see here).
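Concretely, something like this minimal sketch (assuming the brax `v0.10.5` PPO setup, where `reset_fn = jax.jit(jax.vmap(env.reset))` is vmapped over the device axis; `reset_all_envs` and its arguments are illustrative names, with `local_devices_to_use` borrowed from brax's `train.py`):

```python
import jax
import jax.numpy as jnp

def reset_all_envs(key, num_envs, local_devices_to_use, reset_fn):
    # One PRNG key per environment: key_envs has shape (num_envs, 2).
    key_envs = jax.random.split(key, num_envs)
    # Regroup to (local_devices_to_use, num_envs // local_devices_to_use, 2)
    # so the outer vmap axis of reset_fn lines up with the device axis, and
    # the wrapped env's own batching handles the per-device environment axis.
    key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:])
    return reset_fn(key_envs)
```

Passing a flat `(num_envs, 2)` array of keys straight to `reset_fn` makes the outer vmap consume the per-environment axis, which is where the size mismatch comes from.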
-
Hi there,
I'm currently working with a recurrent version of the PPO training algorithm in brax. For recurrent PPO, it's best to evaluate each training step (within a training epoch) over a complete trajectory of the environment, i.e., starting from just after it was last terminated/reset up until it terminates again.
To test things out, I'm looking to reset the environment in the PPO `training_step` function. However, I'm a little confused as to how to do this efficiently and correctly. I've tried replacing these lines with a call to `reset_fn` (see the sketch at the end of this post), where `reset_fn = jax.jit(jax.vmap(env.reset))`. Doing this means I end up with an error regarding array sizes (the 256 comes from me using `num_envs = 256` for training).

I'm using my own versions of the `train.py` and `acting.py` files, which are slightly modified from brax `v0.10.5`. I have made only minimal changes to those files, so I doubt they are the cause of this issue. I can train a policy as normal without introducing the additional reset.

Any help would be greatly appreciated, and I'm happy to provide more info as required!
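For reference, here's a simplified illustration of the pattern I'm attempting (the helper `attempted_reset` is just for this post; in my actual code the reset happens inside `training_step`):

```python
import jax

def attempted_reset(env, key, num_envs=256):
    # A jitted, vmapped reset, called once at the start of each training_step.
    reset_fn = jax.jit(jax.vmap(env.reset))
    # One PRNG key per environment: key_envs has shape (num_envs, 2).
    # This call is where the size-256 error appears.
    key_envs = jax.random.split(key, num_envs)
    return reset_fn(key_envs)
```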