Fix PPO importance-sampling ratio bias with squashed-Gaussian policies #118

Open
jmtoepperwien wants to merge 2 commits into automl:main from jmtoepperwien:fix/ppo-tanh-correction-old-logprobs

Fix PPO importance-sampling ratio bias with squashed-Gaussian policies

Problem

When PPO uses a squashed-Gaussian policy (4-tuple output: action, z, mean, log_std), the importance-sampling ratio becomes biased:

  • During rollout, old_log_probs stored in the buffer are computed without the tanh change-of-variables correction
  • During the PPO update, new_log_probs include the correction
  • The correction does not cancel in the ratio. Writing log_correction = log(1 - tanh(z)²) for one action dimension:
    ratio = exp(log_π_new(a) - log_π_old(a))
          = exp([log N(z|new) - log_correction] - log N(z|old))
          = [N(z|new)/N(z|old)] × 1/(1 - tanh(z)²)

  • The spurious 1/(1 - tanh(z)²) factor grows without bound as actions approach ±1 (because 1 - tanh(z)² → 0), causing gradient instability
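
The bias can be reproduced in isolation. The sketch below is a minimal 1-D illustration, not code from this repository: the helper names (`gaussian_logpdf`, `tanh_log_correction`, `is_ratio`) are hypothetical. It evaluates the ratio for an unchanged policy (new == old), where the correct value is exactly 1, with and without the correction applied to the stored old log-prob.

```python
import math

def gaussian_logpdf(z, mean, log_std):
    """Log-density of a 1-D Gaussian N(mean, exp(log_std)^2) evaluated at z."""
    std = math.exp(log_std)
    return -0.5 * ((z - mean) / std) ** 2 - log_std - 0.5 * math.log(2 * math.pi)

def tanh_log_correction(z):
    """Change-of-variables term log(1 - tanh(z)^2) for a = tanh(z)."""
    return math.log(1.0 - math.tanh(z) ** 2)

def is_ratio(z, mean, log_std, correct_old):
    """Importance-sampling ratio when the new and old policies are identical."""
    # The update path always includes the correction.
    new_lp = gaussian_logpdf(z, mean, log_std) - tanh_log_correction(z)
    # The rollout path includes it only when correct_old is True.
    old_lp = gaussian_logpdf(z, mean, log_std)
    if correct_old:
        old_lp -= tanh_log_correction(z)
    return math.exp(new_lp - old_lp)

for z in (0.0, 1.0, 2.0, 3.0):
    biased = is_ratio(z, 0.0, 0.0, correct_old=False)
    fixed = is_ratio(z, 0.0, 0.0, correct_old=True)
    print(f"z={z:.1f} a={math.tanh(z):+.4f} biased={biased:.2f} fixed={fixed:.2f}")
```

With the correction missing on the old side, the ratio equals 1/(1 - tanh(z)²) = cosh(z)² even though the policy has not changed. With the correction on both sides it stays at 1.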

Solution

Two commits:

  1. Always apply tanh correction at rollout time (97e8391)

    • Remove the sac= gate from sample_nondeterministic_logprobs
    • Apply the correction whenever the model outputs a 4-tuple (which is the canonical signal that squashing is active)
    • This ensures old_log_probs stored in the buffer are already corrected, so the correction cancels correctly in the PPO update
  2. Check tanh_squash attribute explicitly (5e59c37)

    • Add tanh_squash=True to SACModel as a class attribute (matching PPOModel)
    • Pass tanh_squash=getattr(self.model, "tanh_squash", False) to sample_nondeterministic_logprobs at each call site
    • Makes the intent explicit and allows future models to opt in/out cleanly
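
The attribute-based gating in the second commit could look roughly like the sketch below. It only illustrates the pattern and is not the actual `sample_nondeterministic_logprobs` implementation: `logprob_sketch`, the two model classes, and the `1e-6` stabiliser are assumptions made for this example.

```python
import math

class HypotheticalSquashedModel:
    # Class attribute signalling that actions are passed through tanh.
    tanh_squash = True

class HypotheticalPlainModel:
    # No attribute: the getattr default below treats it as unsquashed.
    pass

def logprob_sketch(z, mean, log_std, tanh_squash):
    """Diagonal-Gaussian log-prob of the pre-squash sample z.

    When tanh_squash is True, subtract log(1 - tanh(z_i)^2) per dimension.
    The 1e-6 term is an assumed guard against log(0) for large |z_i|.
    """
    total = 0.0
    for z_i, m_i, ls_i in zip(z, mean, log_std):
        std = math.exp(ls_i)
        total += -0.5 * ((z_i - m_i) / std) ** 2 - ls_i - 0.5 * math.log(2 * math.pi)
        if tanh_squash:
            total -= math.log(1.0 - math.tanh(z_i) ** 2 + 1e-6)
    return total

# Call-site pattern: read the flag from the model, defaulting to False.
model = HypotheticalSquashedModel()
flag = getattr(model, "tanh_squash", False)
log_prob = logprob_sketch([0.5, -1.0], [0.0, 0.0], [0.0, 0.0], tanh_squash=flag)
```

Because the same flag is read from the model at rollout time and at update time, both log-probs either include the correction or omit it, so it cancels in the ratio.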

With squashed-Gaussian policies, old_log_probs stored during rollout
lacked the tanh correction while new_log_probs in the update included
it. The correction did not cancel in the IS ratio, introducing a
multiplicative bias that grows as actions approach ±1.

Fix: remove the sac= gate from sample_nondeterministic_logprobs and
always apply the correction for 4-tuple model output, which is the
canonical signal that tanh squashing is active.

Rather than inferring squashing from the output tuple shape, read the
model's tanh_squash attribute directly. SACModel gains tanh_squash=True
as a class attribute to match the existing PPOModel pattern.
sample_nondeterministic_logprobs now takes an explicit tanh_squash flag.

jmtoepperwien requested a review from amsks on May 28, 2026, 09:04