Issue: Inconsistent Hidden State Handling in act() Method of PPO Implementation
Issue Summary
The act() method in the PPO implementation does not pass hidden states to the self.actor_critic.act(obs) and self.actor_critic.evaluate(critic_obs) calls, leading to inconsistent action/value estimates between rollout (inference) and training. This is especially problematic for recurrent policies (e.g., LSTM/GRU), where past information should influence both action selection and value estimation.
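To see why this matters, consider a minimal recurrent actor: the same observation yields different outputs depending on whether the hidden state is threaded through. The sketch below is hypothetical (a toy GRU module, not the library's ActorCriticRecurrent class) and only illustrates the divergence:

```python
import torch
import torch.nn as nn

# Hypothetical minimal recurrent actor, for illustration only.
class TinyRecurrentActor(nn.Module):
    def __init__(self, obs_dim=4, hidden_dim=8, act_dim=2):
        super().__init__()
        self.gru = nn.GRU(obs_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, act_dim)

    def forward(self, obs, hidden=None):
        # obs: (batch, seq, obs_dim); hidden: (1, batch, hidden_dim) or None
        out, hidden = self.gru(obs, hidden)
        return self.head(out[:, -1]), hidden

torch.manual_seed(0)
actor = TinyRecurrentActor()

# Accumulate a hidden state from an earlier step of the episode.
_, h = actor(torch.randn(1, 1, 4))

obs = torch.randn(1, 1, 4)
with torch.no_grad():
    stateless, _ = actor(obs)    # hidden state implicitly reset to zeros
    stateful, _ = actor(obs, h)  # past information carried forward
print(torch.allclose(stateless, stateful))  # False: the estimates diverge
```

If the rollout uses the stateful path while training reconstructs the stateless one (or vice versa), the stored actions/values no longer match what the policy would actually produce.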
Suggested Fix
Modify act() to pass the stored hidden states into the actor and critic forward passes:
```python
def act(self, obs, critic_obs):
    if self.actor_critic.is_recurrent:
        self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # ---------------------------
        # ✅ Fix: Pass hidden states
        self.transition.actions = self.actor_critic.act(obs, hidden_states=self.transition.hidden_states[0]).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs, hidden_states=self.transition.hidden_states[1]).detach()
        # ---------------------------
    else:
        self.transition.actions = self.actor_critic.act(obs).detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
    self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
    self.transition.action_mean = self.actor_critic.action_mean.detach()
    self.transition.action_sigma = self.actor_critic.action_std.detach()
    self.transition.observations = obs
    self.transition.critic_observations = critic_obs
    return self.transition.actions
```
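The fix assumes that act() and evaluate() accept a hidden_states keyword and that get_hidden_states() returns the actor and critic states as a pair, in that order (hence the [0]/[1] indexing). As a sanity check on those assumptions, here is a minimal sketch of a compatible interface; the names, signatures, and shapes are hypothetical and only show the hidden-state plumbing, not the distribution machinery (action_mean, log-probs, etc.):

```python
import torch
import torch.nn as nn

class RecurrentActorCriticSketch(nn.Module):
    """Hypothetical interface compatible with the suggested fix."""

    is_recurrent = True

    def __init__(self, obs_dim=4, critic_obs_dim=4, hidden_dim=8, act_dim=2):
        super().__init__()
        self.memory_a = nn.GRU(obs_dim, hidden_dim)         # actor RNN
        self.memory_c = nn.GRU(critic_obs_dim, hidden_dim)  # critic RNN
        self.actor_head = nn.Linear(hidden_dim, act_dim)
        self.critic_head = nn.Linear(hidden_dim, 1)
        self.hid_a = None  # (num_layers, batch, hidden_dim) once set
        self.hid_c = None

    def get_hidden_states(self):
        # (actor_hidden, critic_hidden): matches the hidden_states[0] /
        # hidden_states[1] indexing used in the fix above.
        return self.hid_a, self.hid_c

    def act(self, obs, hidden_states=None):
        # Prefer an explicitly passed hidden state; fall back to the
        # internally tracked one.
        h = hidden_states if hidden_states is not None else self.hid_a
        out, self.hid_a = self.memory_a(obs.unsqueeze(0), h)
        return self.actor_head(out.squeeze(0))

    def evaluate(self, critic_obs, hidden_states=None):
        h = hidden_states if hidden_states is not None else self.hid_c
        out, self.hid_c = self.memory_c(critic_obs.unsqueeze(0), h)
        return self.critic_head(out.squeeze(0))
```

If the recurrent actor-critic instead manages its hidden states entirely internally during rollout, the same consistency can be achieved there, but the storage into self.transition.hidden_states must still capture the states as they were before the step so training can replay them.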