
Commit 531750f

Add CartPole examples and update README
1 parent 51fea9d commit 531750f

20 files changed: +1,865 additions, −111 deletions

OptimRL.code-workspace

Lines changed: 0 additions & 8 deletions
This file was deleted.

README.md

Lines changed: 142 additions & 67 deletions
@@ -11,9 +11,9 @@ OptimRL is a **high-performance reinforcement learning library** that introduces
 ![PyTorch](https://img.shields.io/badge/Framework-PyTorch-EE4C2C?logo=pytorch&logoColor=white)
 ![Setuptools](https://img.shields.io/badge/Tool-Setuptools-3776AB?logo=python&logoColor=white)
 ![Build Status](https://github.com/subaashnair/optimrl/actions/workflows/tests.yml/badge.svg)
-![CI](https://github.com/subaashnair/optimrl/workflows/CI/badge.svg)
-![Coverage](https://img.shields.io/codecov/c/github/subaashnair/optimrl)
 ![License](https://img.shields.io/github/license/subaashnair/optimrl)
+<!-- ![Coverage](https://img.shields.io/codecov/c/github/subaashnair/optimrl) -->
+
 
 ## 🌟 Features
 
@@ -39,6 +39,18 @@ OptimRL is a **high-performance reinforcement learning library** that introduces
    - Native integration with deep learning workflows
    - Full automatic differentiation support
 
+5. **🔄 Experience Replay Buffer**
+   Improve sample efficiency with built-in experience replay:
+   - Learn from past experiences multiple times
+   - Reduce correlation between consecutive samples
+   - Configurable buffer capacity and batch sizes
+
+6. **🔄 Continuous Action Space Support**
+   Train agents in environments with continuous control:
+   - Gaussian policy implementation for continuous actions
+   - Configurable action bounds
+   - Adaptive standard deviation for exploration
+
 ---
 
 ## 🛠️ Installation
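
For readers skimming the new "Experience Replay Buffer" and "Continuous Action Space Support" feature bullets above: the snippet below is a minimal, generic sketch of what a capacity-bounded, uniformly sampled replay buffer does, using the same `buffer_capacity` / `batch_size` knobs that the `create_agent()` calls later in this diff expose. It illustrates the concept only; it is not OptimRL's internal buffer implementation, and the class name is made up.

```python
# Generic sketch of experience replay (NOT OptimRL's internal buffer).
# A bounded deque drops the oldest transitions once capacity is reached;
# sampling uniformly from past transitions reduces the correlation between
# consecutive environment steps and lets each experience be reused many times.
import random
from collections import deque

class SketchReplayBuffer:
    def __init__(self, capacity=10_000):
        self.storage = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.storage.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        # Draw a random mini-batch; cap at the current buffer size.
        return random.sample(list(self.storage), min(batch_size, len(self.storage)))

    def __len__(self):
        return len(self.storage)
```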
@@ -61,95 +73,156 @@ pip install -e '.[dev]'
 
 ## ⚡ Quick Start
 
-Here’s a **minimal working example** to get started with OptimRL:
+### Discrete Action Space Example (CartPole)
 
 ```python
 import torch
-import optimrl
-
-# Initialize the GRPO optimizer
-grpo = optimrl.GRPO(epsilon=0.2, beta=0.1)
+import torch.nn as nn
+import torch.optim as optim
+import gym
+from optimrl import create_agent
 
-# Prepare batch data (example)
-batch_data = {
-    'log_probs_old': current_policy_log_probs,
-    'log_probs_ref': reference_policy_log_probs,
-    'rewards': episode_rewards,
-    'group_size': len(episode_rewards)
-}
+# Define a simple policy network
+class PolicyNetwork(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.network = nn.Sequential(
+            nn.Linear(input_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, output_dim),
+            nn.LogSoftmax(dim=-1)
+        )
+
+    def forward(self, x):
+        return self.network(x)
 
-# Compute policy loss
-log_probs_new = new_policy_log_probs
-loss, gradients = grpo.compute_loss(batch_data, log_probs_new)
+# Create environment and network
+env = gym.make('CartPole-v1')
+state_dim = env.observation_space.shape[0]
+action_dim = env.action_space.n
+policy = PolicyNetwork(state_dim, action_dim)
+
+# Create GRPO agent
+agent = create_agent(
+    "grpo",
+    policy_network=policy,
+    optimizer_class=optim.Adam,
+    learning_rate=0.001,
+    gamma=0.99,
+    grpo_params={"epsilon": 0.2, "beta": 0.01},
+    buffer_capacity=10000,
+    batch_size=32
+)
 
-# Apply gradients to update the policy
-optimizer.zero_grad()
-policy_loss = torch.tensor(loss, requires_grad=True)
-policy_loss.backward()
-optimizer.step()
+# Training loop
+state, _ = env.reset()
+for step in range(1000):
+    action = agent.act(state)
+    next_state, reward, done, truncated, _ = env.step(action)
+    agent.store_experience(reward, done)
+
+    if done or truncated:
+        state, _ = env.reset()
+        agent.update()  # Update policy after episode ends
+    else:
+        state = next_state
 ```
 
----
+### Complete CartPole Implementation
+
+For a complete implementation of CartPole with OptimRL, check out our examples in the `simple_test` directory:
 
-## 🔍 Advanced Usage
+- `cartpole_simple.py`: Basic implementation with GRPO
+- `cartpole_improved.py`: Improved implementation with tuned parameters
+- `cartpole_final.py`: Final implementation with optimized performance
+- `cartpole_tuned.py`: Enhanced implementation with advanced features
+- `cartpole_simple_pg.py`: Vanilla Policy Gradient implementation for comparison
 
-Integrate OptimRL seamlessly into your **PyTorch pipelines** or custom training loops. Below is a **complete example** showcasing GRPO in action:
+The vanilla policy gradient implementation (`cartpole_simple_pg.py`) achieves excellent performance on CartPole-v1, reaching the maximum reward of 500 consistently. It serves as a useful baseline for comparing against the GRPO implementations.
+
+### Continuous Action Space Example (Pendulum)
 
 ```python
 import torch
-import optimrl
-
-class PolicyNetwork(torch.nn.Module):
-    def __init__(self, input_dim, output_dim):
+import torch.nn as nn
+import torch.optim as optim
+import gym
+from optimrl import create_agent
+
+# Define a continuous policy network
+class ContinuousPolicyNetwork(nn.Module):
+    def __init__(self, input_dim, action_dim):
         super().__init__()
-        self.network = torch.nn.Sequential(
-            torch.nn.Linear(input_dim, 64),
-            torch.nn.Tanh(),
-            torch.nn.Linear(64, output_dim),
-            torch.nn.LogSoftmax(dim=-1)
+        self.shared_layers = nn.Sequential(
+            nn.Linear(input_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU()
         )
-
+        # Output both mean and log_std for each action dimension
+        self.output_layer = nn.Linear(64, action_dim * 2)
+
     def forward(self, x):
-        return self.network(x)
-
-# Initialize components
-policy = PolicyNetwork(input_dim=4, output_dim=2)
-reference_policy = PolicyNetwork(input_dim=4, output_dim=2)
-optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
-grpo = optimrl.GRPO(epsilon=0.2, beta=0.1)
+        x = self.shared_layers(x)
+        return self.output_layer(x)
+
+# Create environment and network
+env = gym.make('Pendulum-v1')
+state_dim = env.observation_space.shape[0]
+action_dim = env.action_space.shape[0]
+action_bounds = (env.action_space.low[0], env.action_space.high[0])
+policy = ContinuousPolicyNetwork(state_dim, action_dim)
+
+# Create Continuous GRPO agent
+agent = create_agent(
+    "continuous_grpo",
+    policy_network=policy,
+    optimizer_class=optim.Adam,
+    action_dim=action_dim,
+    learning_rate=0.0005,
+    gamma=0.99,
+    grpo_params={"epsilon": 0.2, "beta": 0.01},
+    buffer_capacity=10000,
+    batch_size=64,
+    min_std=0.01,
+    action_bounds=action_bounds
+)
 
 # Training loop
-for episode in range(1000):  # Replace with your num_episodes
-    states, actions, rewards = collect_episode()  # Replace with your data
+state, _ = env.reset()
+for step in range(1000):
+    action = agent.act(state)
+    next_state, reward, done, truncated, _ = env.step(action)
+    agent.store_experience(reward, done)
 
-    # Compute log probabilities
-    with torch.no_grad():
-        log_probs_old = policy(states)
-        log_probs_ref = reference_policy(states)
-
-    batch_data = {
-        'log_probs_old': log_probs_old.numpy(),
-        'log_probs_ref': log_probs_ref.numpy(),
-        'rewards': rewards,
-        'group_size': len(rewards)
-    }
-
-    # Policy update
-    log_probs_new = policy(states)
-    loss, gradients = grpo.compute_loss(batch_data, log_probs_new.numpy())
-
-    # Backpropagation
-    optimizer.zero_grad()
-    policy_loss = torch.tensor(loss, requires_grad=True)
-    policy_loss.backward()
-    optimizer.step()
+    if done or truncated:
+        state, _ = env.reset()
+        agent.update()  # Update policy after episode ends
+    else:
+        state = next_state
 ```
 
+## 📊 Performance Comparison
+
+Our simple policy gradient implementation consistently solves the CartPole-v1 environment in under 1000 episodes, achieving the maximum reward of 500. The GRPO implementations offer competitive performance with additional benefits:
+
+- **Lower variance**: More stable learning across different random seeds
+- **Improved sample efficiency**: Learns from fewer interactions with the environment
+- **Better regularization**: Prevents policy collapse during training
+
+## Kaggle Notebook
+
+You can view the "OptimRL Trading Experiment" notebook on Kaggle:
+[![OptimRL Trading Experiment](https://img.shields.io/badge/Kaggle-OptimRL_Trading_Experiment-orange)](https://www.kaggle.com/code/noir1112/optimrl-trading-experiment/edit)
+
+Alternatively, you can open the notebook locally as an `.ipynb` file:
+[Open the OptimRL Trading Experiment Notebook (.ipynb)](./notebooks/OptimRL_Trading_Experiment.ipynb)
+
 ---
 
 ## 🤝 Contributing
 
-Were excited to have you onboard! Heres how you can help improve **OptimRL**:
+We're excited to have you onboard! Here's how you can help improve **OptimRL**:
 1. **Fork the repo.**
 2. **Create a feature branch**:
 ```bash
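
A note on the continuous-action (Pendulum) example in the hunk above: the network's head returns `action_dim * 2` values, which the agent presumably splits into a mean and a log standard deviation, while `min_std` and `action_bounds` keep exploration noise and actions in range. The snippet below is a hand-rolled sketch of that general technique; it is not OptimRL's actual sampling code, and the helper name `sample_bounded_action` is invented for illustration.

```python
# Hand-rolled sketch (not OptimRL's sampling code): turn a (mean, log_std)
# policy output into a bounded stochastic action, Pendulum-style.
import torch

def sample_bounded_action(policy_output, low=-2.0, high=2.0, min_std=0.01):
    mean, log_std = policy_output.chunk(2, dim=-1)   # split the 2*action_dim head
    std = log_std.exp().clamp(min=min_std)           # keep a floor under exploration noise
    action = torch.distributions.Normal(mean, std).sample()
    return action.clamp(low, high)                   # respect the environment's action bounds

# Example with a fake 1-D action head output [mean, log_std]:
print(sample_bounded_action(torch.tensor([0.3, -1.0])))
```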
@@ -185,7 +258,7 @@ If you use OptimRL in your research, please cite:
 ```bibtex
 @software{optimrl2024,
   title={OptimRL: Group Relative Policy Optimization},
-  author={Your Name},
+  author={Subashan Nair},
   year={2024},
   url={https://github.com/subaashnair/optimrl}
 }
@@ -194,3 +267,5 @@ If you use OptimRL in your research, please cite:
 ---
 
 
+
+
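
As background for the `grpo_params={"epsilon": 0.2, "beta": 0.01}` settings that appear throughout the updated README: assuming OptimRL follows the standard Group Relative Policy Optimization objective from the literature, `epsilon` is the clipping range on the probability ratio and `beta` weights a KL penalty against a reference policy, with advantages normalized within each group of rewards. The removed quick-start in this diff passed `log_probs_old`, `log_probs_ref`, `rewards`, and `group_size` to `compute_loss`, which is consistent with that formulation:

$$
J_{\text{GRPO}}(\theta)=\mathbb{E}\!\left[\frac{1}{G}\sum_{i=1}^{G}\min\!\Big(r_i(\theta)\,\hat{A}_i,\ \operatorname{clip}\big(r_i(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_i\Big)\right]-\beta\,D_{\mathrm{KL}}\!\big(\pi_\theta\,\|\,\pi_{\text{ref}}\big),
\quad
r_i(\theta)=\frac{\pi_\theta(a_i\mid s_i)}{\pi_{\theta_{\text{old}}}(a_i\mid s_i)},
\quad
\hat{A}_i=\frac{R_i-\operatorname{mean}(R_{1:G})}{\operatorname{std}(R_{1:G})}.
$$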

cartpole_rewards_final.png (103 KB)
cartpole_rewards_improved.png (59.3 KB)
cartpole_rewards_simple.png (53.3 KB)
cartpole_training_progress.png (98.7 KB)
cartpole_training_progress_pg.png (91.3 KB)

examples/cartpole_example.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+# Example of training a GRPO agent on the CartPole environment
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import gym
+import numpy as np
+import matplotlib.pyplot as plt
+from optimrl import GRPO, GRPOAgent, create_agent
+
+# Define a simple policy network for CartPole
+class PolicyNetwork(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.network = nn.Sequential(
+            nn.Linear(input_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU(),
+            nn.Linear(64, output_dim),
+            nn.LogSoftmax(dim=-1)
+        )
+
+    def forward(self, x):
+        return self.network(x)
+
+def train_cartpole(episodes=500, render=False):
+    # Create the CartPole environment
+    env = gym.make('CartPole-v1')
+
+    # Get environment dimensions
+    state_dim = env.observation_space.shape[0]  # 4 for CartPole
+    action_dim = env.action_space.n  # 2 for CartPole
+
+    # Create the policy network
+    policy_network = PolicyNetwork(state_dim, action_dim)
+
+    # Initialize the GRPO agent
+    agent = create_agent(
+        "grpo",
+        policy_network=policy_network,
+        optimizer_class=optim.Adam,
+        learning_rate=0.001,
+        gamma=0.99,
+        grpo_params={"epsilon": 0.2, "beta": 0.01},
+        buffer_capacity=10000,
+        batch_size=32
+    )
+
+    # Training loop
+    rewards_history = []
+
+    for episode in range(episodes):
+        state, _ = env.reset()
+        episode_reward = 0
+        done = False
+
+        while not done:
+            if render and episode % 50 == 0:
+                env.render()
+
+            # Select an action
+            action = agent.act(state)
+
+            # Take the action in the environment
+            next_state, reward, done, truncated, _ = env.step(action)
+            done = done or truncated
+
+            # Store experience and update policy
+            agent.store_experience(reward, done)
+
+            # Update state and reward
+            state = next_state
+            episode_reward += reward
+
+        # Update policy after episode ends
+        agent.update()
+
+        # Record rewards
+        rewards_history.append(episode_reward)
+
+        # Print progress
+        if (episode + 1) % 10 == 0:
+            avg_reward = np.mean(rewards_history[-10:])
+            print(f"Episode {episode + 1}/{episodes} | Avg Reward: {avg_reward:.2f}")
+
+    env.close()
+
+    # Plot rewards
+    plt.figure(figsize=(10, 6))
+    plt.plot(rewards_history)
+    plt.xlabel('Episode')
+    plt.ylabel('Total Reward')
+    plt.title('GRPO on CartPole-v1')
+    plt.savefig('cartpole_rewards.png')
+    plt.show()
+
+    return rewards_history, policy_network
+
+if __name__ == "__main__":
+    rewards, model = train_cartpole(episodes=300, render=False)
+    print("Training completed!")
+
+    # Save the trained model
+    torch.save(model.state_dict(), "cartpole_grpo_model.pt")
+    print("Model saved to cartpole_grpo_model.pt")
