A JAX implementation of policy gradient methods with stability enhancements for CartPole-v1. Features baseline subtraction, entropy regularization, gradient clipping, and per-episode normalization for robust training.
This implementation demonstrates how various enhancements to vanilla policy gradient methods improve training stability and performance. The key improvements include:
- Baseline subtraction for variance reduction
- Entropy regularization for exploration
- Gradient clipping for stability
- Per-episode return normalization to remove bias from episode-length differences
Example output from a comparison run:

```
(venv) ➜ jax-vpg git:(master) ✗ python3 main.py --compare --episodes=2000 --episodes-per-update=10
🎯 Starting VPG Variant Comparison...
Running 4 variants with 2000 episodes each
🚀 Running VPG (Basic)...
Description: Basic VPG with no enhancements
Episode 0-9: Avg Reward: 17.60, Loss: 1144.4250
Episode 10-19: Avg Reward: 22.90, Loss: 2510.6296
...
✅ VPG (Basic) completed successfully
🚀 Running With Baseline...
Description: VPG with baseline subtraction for variance reduction
Episode 0-9: Avg Reward: 24.20, Loss: 292.3097
Episode 10-19: Avg Reward: 23.40, Loss: 1.4594
...
✅ With Baseline completed successfully
🚀 Running With Entropy...
Description: VPG with baseline and entropy regularization
Episode 0-9: Avg Reward: 18.90, Loss: 66.3463
Episode 10-19: Avg Reward: 21.50, Loss: -14.5288
...
✅ With Entropy completed successfully
🚀 Running Full Enhanced...
Description: VPG with all enhancements
Episode 0-9: Avg Reward: 17.60, Loss: 13.7638
Episode 10-19: Avg Reward: 17.70, Loss: -0.1446
...
✅ Full Enhanced completed successfully
📊 Comparison plot saved to images/vpg_comparison.png
📊 Summary:
VPG (Basic): Final reward = 142.60
With Baseline: Final reward = 492.30
With Entropy: Final reward = 500.00
Full Enhanced: Final reward = 148.10
```
Usage examples:

```sh
# Train with default settings (all enhancements enabled)
python3 main.py

# Basic VPG (no enhancements)
python3 main.py --no-baseline --no-entropy --no-gradient-clipping --no-per-episode-norm

# Just baseline subtraction
python3 main.py --no-entropy --no-gradient-clipping --no-per-episode-norm

# Custom training parameters
python3 main.py --episodes 2000 --lr 0.0005 --entropy-coef 0.02

# Headless training (no rendering)
python3 main.py --no-render --no-plot

# Quick test run
python3 main.py --episodes 200 --episodes-per-update 10 --eval-episodes 1

# Run comparison of all variants defined in variants.py
python3 main.py --compare

# Run comparison with custom parameters
python3 main.py --compare --episodes 1000 --episodes-per-update 20

# List available variants
python3 main.py --list-variants

# Show all available options
python3 main.py --help
```

Install dependencies (the `-f` index provides CUDA-enabled JAX wheels):

```sh
pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
This implementation extends the vanilla policy gradient (VPG) algorithm with several key improvements for better training stability and performance:
Instead of using raw returns R_t in the policy gradient, we compute advantages A_t = R_t - b, where b = E[R_t] is the baseline. This reduces the variance of the gradient estimates:
∇_θ J(θ) = E[∇_θ log π_θ(a|s) · A_t]
Implemented as `baseline = jnp.mean(jnp.array(returns))` and `advantage = ret - baseline`.
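A minimal sketch of that computation over a batch of returns (the function wrapper is illustrative; `main.py` applies it inline as shown above):

```python
import jax.numpy as jnp

def compute_advantages(returns):
    """Subtract the mean return (the baseline b) from each return."""
    returns = jnp.array(returns)
    baseline = jnp.mean(returns)   # b = E[R_t]
    return returns - baseline      # A_t = R_t - b
```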
Added gradient clipping to prevent exploding gradients. The global gradient norm (across all parameters) is clipped to a maximum value of 1.0:
grad = grad · min(1.0, 1.0/||grad||_2)
Implemented using `optax.clip_by_global_norm(1.0)` in the optimizer chain.
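A sketch of such a chain, assuming Adam as the base optimizer (the learning rate here is a placeholder for the `--lr` flag):

```python
import optax

# optax.chain applies transformations in order, so the global-norm
# clip happens before the Adam update is computed.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),  # placeholder; set via --lr
)
```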
Returns are normalized per episode to have zero mean and unit variance:
R'_t = (R_t - μ_e) / (σ_e + ε)
where μ_e and σ_e are the mean and standard deviation of returns in episode e. This prevents bias from episode length differences.
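A minimal sketch, assuming `episode_returns` holds the returns of a single episode and ε = 1e-8:

```python
import jax.numpy as jnp

def normalize_returns(episode_returns, eps=1e-8):
    """R'_t = (R_t - mu_e) / (sigma_e + eps), computed within one episode."""
    r = jnp.array(episode_returns)
    return (r - jnp.mean(r)) / (jnp.std(r) + eps)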
Added entropy regularization to encourage exploration. The loss function becomes:
L(θ) = -log π_θ(a|s) · A_t - β · H(π_θ)
where H(π_θ) = -Σ_a π_θ(a|s) log π_θ(a|s) is the policy entropy and β is the entropy coefficient (0.01 by default).
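A sketch of that loss for a categorical (softmax) policy, assuming the network outputs per-action logits; the function and argument names are illustrative:

```python
import jax
import jax.numpy as jnp

def vpg_loss(logits, actions, advantages, entropy_coef=0.01):
    """Batch-averaged loss: -log pi(a|s) * A_t - beta * H(pi_theta)."""
    log_probs = jax.nn.log_softmax(logits)  # log pi(.|s), shape (B, n_actions)
    # actions: int array of chosen action indices, shape (B,)
    chosen = jnp.take_along_axis(log_probs, actions[:, None], axis=1)[:, 0]
    entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)  # H(pi(.|s))
    return jnp.mean(-chosen * advantages - entropy_coef * entropy)
```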
To add your own variants for comparison, edit `variants.py`:

```python
# Add custom variants
CUSTOM_VARIANTS = {
    "High Entropy": {
        "use_baseline": True,
        "use_entropy": True,
        "use_gradient_clipping": True,
        "use_per_episode_norm": True,
        "entropy_coef": 0.05,  # Higher entropy
        "description": "VPG with high entropy regularization",
    }
}

def get_variants():
    return CUSTOM_VARIANTS  # Override to use your variants
```
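After editing `variants.py`, rerun `python3 main.py --compare` to include your custom variants in the comparison.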