Commit 4575eca

Merge pull request #113 from cpnota/release/0.3.3
Release/0.3.3
2 parents 09eca2d + c9aa786 commit 4575eca

12 files changed: +177 additions, -60 deletions

.pylintrc

Lines changed: 2 additions & 1 deletion

@@ -437,7 +437,8 @@ good-names=i,
            n,
            t,
            e,
-           kl
+           kl,
+           ax
 
 # Include a hint for the correct naming format with invalid-name.
 include-naming-hint=no

README.md

Lines changed: 31 additions & 23 deletions

@@ -2,54 +2,62 @@
 
 The Autonomous Learning Library (`all`) is an object-oriented deep reinforcement learning library in `pytorch`. The goal of the library is to provide implementations of modern reinforcement learning algorithms that reflect the way that reinforcement learning researchers think about agent design and to provide the components necessary to build and test new ideas with minimal overhead.
 
-## Algorithms
-
-As of today, `all` contains implementations of the following deep RL algorithms:
-
-- [x] Advantage Actor-Critic (A2C)
-- [x] Categorical DQN (C51)
-- [x] Deep Deterministic Policy Gradient (DDPG)
-- [x] Deep Q-Learning (DQN) + extensions
-- [x] Proximal Policy Optimization (PPO)
-- [x] Rainbow (Rainbow)
-- [x] Soft Actor-Critic (SAC)
-
-It also contains implementations of the following "vanilla" agents, which provide useful baselines and perform better than you may expect:
-
-- [x] Vanilla Actor-Critic
-- [x] Vanilla Policy Gradient
-- [x] Vanilla Q-Learning
-- [x] Vanilla Sarsa
-
-We will try to stay up-to-date with advances in the field, but we do not intend to implement every algorithm. Rather, we prefer to maintain a smaller set of high-quality agents that have achieved notoriety in the field.
-
 ## Why use `all`?
 
 The primary reason for using `all` over its many competitors is because it contains components that allow you to *build your own* reinforcement learning agents.
 We provide out-of-the-box modules for:
 
 - [x] Custom Q-Networks, V-Networks, policy networks, and feature networks
 - [x] Generic function approximation
+- [x] Target networks
+- [x] Polyak averaging
 - [x] Experience Replay
 - [x] Prioritized Experience Replay
 - [x] Advantage Estimation
 - [x] Generalized Advantage Estimation (GAE)
-- [x] Target networks
-- [x] Polyak averaging
 - [x] Easy parameter and learning rate scheduling
 - [x] An enhanced `nn` module (includes dueling layers, noisy layers, action bounds, and the coveted `nn.Flatten`)
 - [x] `gym` to `pytorch` wrappers
 - [x] Atari wrappers
 - [x] An `Experiment` API for comparing and evaluating agents
 - [x] A `SlurmExperiment` API for running massive experiments on computing clusters
 - [x] A `Writer` object for easily logging information in `tensorboard`
+- [x] Plotting utilities for generating paper-worthy result plots
 
 Rather than being embedded in the agents, all of these modules are available for use by your own custom agents.
 Additionally, the included agents accept custom versions of any of the above objects.
 Have a new type of replay buffer in mind?
 Code it up and pass it directly to our `DQN` and `DDPG` implementations.
 Additionally, our agents were written with readibility as a primary concern, so they are easy to modify.
 
+## Algorithms
+
+As of today, `all` contains implementations of the following deep RL algorithms:
+
+- [x] Advantage Actor-Critic (A2C)
+- [x] Categorical DQN (C51)
+- [x] Deep Deterministic Policy Gradient (DDPG)
+- [x] Deep Q-Learning (DQN) + extensions
+- [x] Proximal Policy Optimization (PPO)
+- [x] Rainbow (Rainbow)
+- [x] Soft Actor-Critic (SAC)
+
+It also contains implementations of the following "vanilla" agents, which provide useful baselines and perform better than you may expect:
+
+- [x] Vanilla Actor-Critic
+- [x] Vanilla Policy Gradient
+- [x] Vanilla Q-Learning
+- [x] Vanilla Sarsa
+
+We will try to stay up-to-date with advances in the field, but we do not intend to implement every algorithm. Rather, we prefer to maintain a smaller set of high-quality agents that have achieved notoriety in the field.
+
+We have labored to make sure that our implementations produce results comparable to published results.
+Here's a sampling of performance on several Atari games:
+
+![atari40](atari40.png)
+
+These results were generated using the `all.presets.atari` module, the `SlurmExperiment` utility, and the `all.experiments.plots` module.
+
 ## Example
 
 Our agents implement a single method: `action = agent.act(state, reward)`.
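A minimal sketch of how that single-method interface is typically driven, using a plain `gym`-style loop and a placeholder random agent (neither is a preset from the library; the classic 4-tuple `step` API is assumed):

```python
# Sketch only: drives the `action = agent.act(state, reward)` interface
# described in the README with a placeholder agent and a gym environment.
import gym

env = gym.make("CartPole-v0")


class RandomAgent:
    """Placeholder agent: ignores state/reward and samples random actions."""
    def __init__(self, action_space):
        self.action_space = action_space

    def act(self, state, reward):
        return self.action_space.sample()


my_agent = RandomAgent(env.action_space)
state = env.reset()
reward = 0
done = False
while not done:
    action = my_agent.act(state, reward)
    state, reward, done, _ = env.step(action)  # classic gym 4-tuple API
```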

all/experiments/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from .experiment import Experiment
+from .plots import plot_returns_100
 from .slurm import SlurmExperiment
 from .watch import GreedyAgent, watch, load_and_watch
 

all/experiments/experiment_test.py

Lines changed: 3 additions & 0 deletions

@@ -32,6 +32,9 @@ def add_schedule(self, name, value, step="frame"):
     def add_evaluation(self, name, value, step="frame"):
         self.add_scalar("evaluation/" + name, value, self._get_step(step))
 
+    def add_summary(self, name, mean, std, step="frame"):
+        pass
+
     def _get_step(self, _type):
         if _type == "frame":
             return self.frames

all/experiments/plots.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def plot_returns_100(runs_dir, timesteps=-1):
+    data = load_returns_100_data(runs_dir)
+    lines = {}
+    fig, axes = plt.subplots(1, len(data))
+    for i, env in enumerate(sorted(data.keys())):
+        ax = axes[i]
+        subplot_returns_100(ax, env, data[env], lines, timesteps=timesteps)
+    fig.legend(list(lines.values()), list(lines.keys()), loc="center right")
+    plt.show()
+
+
+def load_returns_100_data(runs_dir):
+    data = {}
+
+    def add_data(agent, env, file):
+        if not env in data:
+            data[env] = {}
+        data[env][agent] = np.genfromtxt(file, delimiter=",").reshape((-1, 3))
+
+    for agent_dir in os.listdir(runs_dir):
+        agent = agent_dir.split(" ")[0].strip("_")
+        agent_path = os.path.join(runs_dir, agent_dir)
+        if os.path.isdir(agent_path):
+            for env in os.listdir(agent_path):
+                env_path = os.path.join(agent_path, env)
+                if os.path.isdir(env_path):
+                    returns100path = os.path.join(env_path, "returns100.csv")
+                    if os.path.exists(returns100path):
+                        add_data(agent, env, returns100path)
+
+    return data
+
+
+def subplot_returns_100(ax, env, data, lines, timesteps=-1):
+    for agent in data:
+        agent_data = data[agent]
+        x = agent_data[:, 0]
+        mean = agent_data[:, 1]
+        std = agent_data[:, 2]
+
+        if timesteps > 0:
+            x[-1] = timesteps
+
+        if agent in lines:
+            ax.plot(x, mean, label=agent, color=lines[agent].get_color())
+        else:
+            line, = ax.plot(x, mean, label=agent)
+            lines[agent] = line
+        ax.fill_between(
+            x, mean + std, mean - std, alpha=0.2, color=lines[agent].get_color()
+        )
+        ax.set_title(env)
+        ax.set_xlabel("timesteps")
+        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 5))
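A short usage sketch for the new module: assuming results were logged under the default `runs` directory produced by `ExperimentWriter`, the plot can be generated in two lines (the `timesteps` value here is only an example):

```python
# Load every runs/<agent>/<env>/returns100.csv under the default log
# directory and draw one subplot per environment, one line per agent.
from all.experiments import plot_returns_100

# timesteps=40e6 (optional) overwrites the x-coordinate of each curve's
# final point so all curves end at the intended training length.
plot_returns_100("runs", timesteps=40e6)
```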

all/experiments/runner.py

Lines changed: 3 additions & 2 deletions

@@ -44,8 +44,9 @@ def _log(self, returns, fps):
             self._best_returns = returns
         self._returns100.append(returns)
         if len(self._returns100) == 100:
-            self._writer.add_evaluation('returns100/mean', np.mean(self._returns100), step="frame")
-            self._writer.add_evaluation('returns100/std', np.std(self._returns100), step="frame")
+            mean = np.mean(self._returns100)
+            std = np.std(self._returns100)
+            self._writer.add_summary('returns100', mean, std, step="frame")
             self._returns100 = []
         self._writer.add_evaluation('returns/episode', returns, step="episode")
         self._writer.add_evaluation('returns/frame', returns, step="frame")

all/logging/__init__.py

Lines changed: 29 additions & 8 deletions

@@ -1,12 +1,13 @@
-
+import csv
 import os
 import subprocess
 from abc import ABC, abstractmethod
 from datetime import datetime
 from tensorboardX import SummaryWriter
 
+
 class Writer(ABC):
-    log_dir = 'runs'
+    log_dir = "runs"
 
     @abstractmethod
     def add_loss(self, name, value, step="frame"):
@@ -24,6 +25,10 @@ def add_scalar(self, name, value, step="frame"):
     def add_schedule(self, name, value, step="frame"):
         pass
 
+    @abstractmethod
+    def add_summary(self, name, mean, std, step="frame"):
+        pass
+
 
 class DummyWriter(Writer):
     def add_loss(self, name, value, step="frame"):
@@ -38,13 +43,21 @@ def add_scalar(self, name, value, step="frame"):
     def add_schedule(self, name, value, step="frame"):
         pass
 
+    def add_summary(self, name, mean, std, step="frame"):
+        pass
+
 
 class ExperimentWriter(SummaryWriter, Writer):
     def __init__(self, agent_name, env_name, loss=True):
         self.env_name = env_name
         current_time = str(datetime.now())
+        os.makedirs(
+            os.path.join(
+                "runs", ("%s %s %s" % (agent_name, COMMIT_HASH, current_time)), env_name
+            )
+        )
         self.log_dir = os.path.join(
-            'runs', ("%s %s %s" % (agent_name, COMMIT_HASH, current_time))
+            "runs", ("%s %s %s" % (agent_name, COMMIT_HASH, current_time))
         )
         self._frames = 0
         self._episodes = 1
@@ -56,15 +69,22 @@ def add_loss(self, name, value, step="frame"):
         self.add_scalar("loss/" + name, value, step)
 
     def add_evaluation(self, name, value, step="frame"):
-        self.add_scalar('evaluation/' + name, value, self._get_step(step))
+        self.add_scalar("evaluation/" + name, value, self._get_step(step))
 
     def add_schedule(self, name, value, step="frame"):
         if self._loss:
-            self.add_scalar('schedule' + '/' + name, value, self._get_step(step))
+            self.add_scalar("schedule" + "/" + name, value, self._get_step(step))
 
     def add_scalar(self, name, value, step="frame"):
         super().add_scalar(self.env_name + "/" + name, value, self._get_step(step))
 
+    def add_summary(self, name, mean, std, step="frame"):
+        self.add_evaluation(name + "/mean", mean, step)
+        self.add_evaluation(name + "/std", std, step)
+
+        with open(os.path.join(self.log_dir, self.env_name, name + ".csv"), "a") as csvfile:
+            csv.writer(csvfile).writerow([self._get_step(step), mean, std])
+
     def _get_step(self, _type):
         if _type == "frame":
             return self.frames
@@ -91,13 +111,14 @@ def episodes(self, episodes):
 
 def get_commit_hash():
     result = subprocess.run(
-        ['git', 'rev-parse', '--short', 'HEAD'], stdout=subprocess.PIPE)
-    return result.stdout.decode('utf-8').rstrip()
+        ["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE
+    )
+    return result.stdout.decode("utf-8").rstrip()
 
 
 COMMIT_HASH = get_commit_hash()
 
 try:
-    os.mkdir('runs')
+    os.mkdir("runs")
 except FileExistsError:
     pass
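For reference, the rows appended by `add_summary` are plain `step, mean, std` triples written to `<log_dir>/<env>/<name>.csv`, which is exactly the format `all/experiments/plots.py` reads back. A small sketch of reading one such file directly (the path below is hypothetical; `ExperimentWriter` names run directories `"<agent> <commit hash> <timestamp>"`):

```python
import numpy as np

# Hypothetical run directory and environment name, for illustration only;
# the real name embeds the commit hash and a datetime string.
path = "runs/a2c abc1234 2019-11-01 12:00:00.000000/Breakout/returns100.csv"

# Each row written by add_summary is [step, mean, std].
rows = np.genfromtxt(path, delimiter=",").reshape((-1, 3))
steps, means, stds = rows[:, 0], rows[:, 1], rows[:, 2]
print(steps[-1], means[-1], stds[-1])  # most recent 100-episode return summary
```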

all/presets/atari/a2c.py

Lines changed: 29 additions & 16 deletions

@@ -1,6 +1,7 @@
 # /Users/cpnota/repos/autonomous-learning-library/all/approximation/value/action/torch.py
 import torch
-from torch.optim import RMSprop
+from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
 from all.agents import A2C
 from all.bodies import DeepmindAtariBody
 from all.approximation import VNetwork, FeatureNetwork
@@ -10,42 +11,50 @@
 
 
 def a2c(
-        # taken from stable-baselines
+        # Common settings
+        device=torch.device('cuda'),
         discount_factor=0.99,
-        n_steps=5,
-        value_loss_scaling=0.25,
-        entropy_loss_scaling=0.01,
+        last_frame=40e6,
+        # Adam optimizer settings
+        lr=7e-4,
+        eps=1.5e-4,
+        # Other optimization settings
         clip_grad=0.5,
-        lr=7e-4,  # RMSprop learning rate
-        alpha=0.99,  # RMSprop momentum decay
-        eps=1e-5,  # RMSprop stability
+        entropy_loss_scaling=0.01,
+        value_loss_scaling=0.5,
+        # Batch settings
         n_envs=16,
-        device=torch.device("cuda"),
+        n_steps=5,
 ):
+    final_anneal_step = last_frame / (n_steps * n_envs * 4)
     def _a2c(envs, writer=DummyWriter()):
         env = envs[0]
 
         value_model = nature_value_head().to(device)
         policy_model = nature_policy_head(envs[0]).to(device)
         feature_model = nature_features().to(device)
 
-        feature_optimizer = RMSprop(
-            feature_model.parameters(), alpha=alpha, lr=lr, eps=eps
-        )
-        value_optimizer = RMSprop(value_model.parameters(), alpha=alpha, lr=lr, eps=eps)
-        policy_optimizer = RMSprop(
-            policy_model.parameters(), alpha=alpha, lr=lr, eps=eps
-        )
+        feature_optimizer = Adam(feature_model.parameters(), lr=lr, eps=eps)
+        value_optimizer = Adam(value_model.parameters(), lr=lr, eps=eps)
+        policy_optimizer = Adam(policy_model.parameters(), lr=lr, eps=eps)
 
         features = FeatureNetwork(
             feature_model,
            feature_optimizer,
+            scheduler=CosineAnnealingLR(
+                feature_optimizer,
+                final_anneal_step,
+            ),
             clip_grad=clip_grad,
             writer=writer
         )
         v = VNetwork(
             value_model,
             value_optimizer,
+            scheduler=CosineAnnealingLR(
+                value_optimizer,
+                final_anneal_step,
+            ),
             loss_scaling=value_loss_scaling,
             clip_grad=clip_grad,
             writer=writer
@@ -54,6 +63,10 @@ def _a2c(envs, writer=DummyWriter()):
             policy_model,
             policy_optimizer,
             env.action_space.n,
+            scheduler=CosineAnnealingLR(
+                policy_optimizer,
+                final_anneal_step,
+            ),
             entropy_loss_scaling=entropy_loss_scaling,
             clip_grad=clip_grad,
             writer=writer
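One note on the new scheduling logic: `final_anneal_step` converts the training length in frames into a number of optimizer updates for the cosine schedule. A small sketch of the arithmetic with the defaults above (treating the factor of 4 as the Atari frame skip applied by `DeepmindAtariBody` is an assumption; the diff itself does not say):

```python
# Defaults from the preset above.
last_frame = 40e6   # total environment frames to train for
n_steps = 5         # rollout length per update
n_envs = 16         # parallel environments
frame_skip = 4      # assumption: frames consumed per agent step

# Frames per optimizer update = n_steps * n_envs * frame_skip = 320,
# so the cosine schedule anneals over 40e6 / 320 = 125000 updates.
final_anneal_step = last_frame / (n_steps * n_envs * frame_skip)
print(final_anneal_step)  # 125000.0
```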
