
Commit 5f11927

Merge pull request #1 from hill-a/fixes_cleanup
Fixes and cleanup
2 parents (978e116 + 3a4dcbd), commit 5f11927

129 files changed: +8684, -5300 lines

.coveragerc (+16)

@@ -0,0 +1,16 @@
+[run]
+branch = False
+omit =
+    baselines/common/tests/*
+    # Mujoco requires a licence
+    baselines/*/run_mujoco.py
+    baselines/ppo1/run_humanoid.py
+    baselines/ppo1/run_robotics.py
+    # HER requires mpi and Mujoco
+    baselines/her/experiment/
+
+[report]
+exclude_lines =
+    pragma: no cover
+    raise NotImplementedError()
+    if KFAC_DEBUG:

.gitignore (+4, -2)

@@ -2,9 +2,13 @@
 *.pyc
 *.pkl
 *.py~
+*.bak
 .pytest_cache
 .DS_Store
 .idea
+.coverage
+.coverage.*
+__pycache__/

 # Setuptools distribution and build folders.
 /dist/
@@ -34,5 +38,3 @@ src
 .cache

 MUJOCO_LOG.TXT
-
-

.travis.yml (+4, -1)

@@ -2,6 +2,9 @@ language: python
 python:
   - "3.6"

+notifications:
+  email: false
+
 services:
   - docker

@@ -11,4 +14,4 @@ install:

 script:
   - flake8 --select=F baselines/common
-  - docker run baselines-test pytest
+  - docker run --env CODACY_PROJECT_TOKEN=$CODACY_PROJECT_TOKEN baselines-test sh -c 'pytest --cov-config .coveragerc --cov-report term --cov-report xml --cov=. && python-codacy-coverage -r coverage.xml --token=$CODACY_PROJECT_TOKEN'

Dockerfile (+29, -4)

@@ -1,18 +1,43 @@
 FROM ubuntu:16.04

-RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake
+RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake libglib2.0-0 libsm6 libxext6 libfontconfig1 libxrender1
 ENV CODE_DIR /root/code
 ENV VENV /root/venv

-COPY . $CODE_DIR/baselines
 RUN \
     pip install virtualenv && \
     virtualenv $VENV --python=python3 && \
     . $VENV/bin/activate && \
+    mkdir $CODE_DIR && \
     cd $CODE_DIR && \
     pip install --upgrade pip && \
-    pip install -e baselines && \
-    pip install pytest
+    pip install pytest && \
+    pip install pytest-cov && \
+    pip install codacy-coverage && \
+    pip install scipy && \
+    pip install tqdm && \
+    pip install joblib && \
+    pip install zmq && \
+    pip install dill && \
+    pip install progressbar2 && \
+    pip install mpi4py && \
+    pip install cloudpickle && \
+    pip install tensorflow>=1.4.0 && \
+    pip install click && \
+    pip install opencv-python && \
+    pip install numpy && \
+    pip install pandas && \
+    pip install pytest && \
+    pip install matplotlib && \
+    pip install seaborn && \
+    pip install glob2 && \
+    pip install gym[mujoco,atari,classic_control,robotics]
+
+COPY . $CODE_DIR/baselines
+RUN \
+    . $VENV/bin/activate && \
+    cd $CODE_DIR && \
+    pip install -e baselines

 ENV PATH=$VENV/bin:$PATH
 WORKDIR $CODE_DIR/baselines

README.md (+1, -1)

@@ -1,4 +1,4 @@
-<img src="data/logo.jpg" width=25% align="right" /> [![Build status](https://travis-ci.org/openai/baselines.svg?branch=master)](https://travis-ci.org/openai/baselines)
+<img src="data/logo.jpg" width=25% align="right" /> [![Build Status](https://travis-ci.org/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.org/hill-a/stable-baselines) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=hill-a/stable-baselines&amp;utm_campaign=Badge_Grade) [![Codacy Badge](https://api.codacy.com/project/badge/Coverage/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&utm_medium=referral&utm_content=hill-a/stable-baselines&utm_campaign=Badge_Coverage)

 # Baselines


baselines/a2c/a2c.py (+115, -60)

@@ -1,75 +1,91 @@
-import os.path as osp
+import os
 import time
 import joblib
+
 import numpy as np
 import tensorflow as tf
-from baselines import logger

-from baselines.common import set_global_seeds, explained_variance
+from baselines import logger
+from baselines.common import set_global_seeds, explained_variance, tf_util
 from baselines.common.runners import AbstractEnvRunner
-from baselines.common import tf_util
+from baselines.a2c.utils import discount_with_dones, Scheduler, make_path, find_trainable_variables, calc_entropy, mse

-from baselines.a2c.utils import discount_with_dones
-from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables
-from baselines.a2c.utils import cat_entropy, mse

 class Model(object):
-
-    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps,
-            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
-            alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
+    def __init__(self, policy, ob_space, ac_space, n_envs, n_steps,
+                 ent_coef=0.01, vf_coef=0.25, max_grad_norm=0.5, learning_rate=7e-4,
+                 alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lr_schedule='linear'):
+        """
+        The A2C (Advantage Actor Critic) model class, https://arxiv.org/abs/1602.01783
+
+        :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
+        :param ob_space: (Gym Space) Observation space
+        :param ac_space: (Gym Space) Action space
+        :param n_envs: (int) The number of environments
+        :param n_steps: (int) The number of steps to run for each environment
+        :param ent_coef: (float) Entropy coefficient for the loss caculation
+        :param vf_coef: (float) Value function coefficient for the loss calculation
+        :param max_grad_norm: (float) The maximum value for the gradient clipping
+        :param learning_rate: (float) The learning rate
+        :param alpha: (float) RMS prop optimizer decay
+        :param epsilon: (float) RMS prop optimizer epsilon
+        :param total_timesteps: (int) The total number of samples
+        :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
+                            'double_linear_con', 'middle_drop' or 'double_middle_drop')
+        """

         sess = tf_util.make_session()
-        nbatch = nenvs*nsteps
+        n_batch = n_envs * n_steps

-        A = tf.placeholder(tf.int32, [nbatch])
-        ADV = tf.placeholder(tf.float32, [nbatch])
-        R = tf.placeholder(tf.float32, [nbatch])
-        LR = tf.placeholder(tf.float32, [])
+        actions_ph = tf.placeholder(tf.int32, [n_batch])
+        advs_ph = tf.placeholder(tf.float32, [n_batch])
+        rewards_ph = tf.placeholder(tf.float32, [n_batch])
+        learning_rate_ph = tf.placeholder(tf.float32, [])

-        step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False)
-        train_model = policy(sess, ob_space, ac_space, nenvs*nsteps, nsteps, reuse=True)
+        step_model = policy(sess, ob_space, ac_space, n_envs, 1, reuse=False)
+        train_model = policy(sess, ob_space, ac_space, n_envs * n_steps, n_steps, reuse=True)

-        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A)
-        pg_loss = tf.reduce_mean(ADV * neglogpac)
-        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), R))
-        entropy = tf.reduce_mean(cat_entropy(train_model.pi))
-        loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef
+        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.policy, labels=actions_ph)
+        pg_loss = tf.reduce_mean(advs_ph * neglogpac)
+        vf_loss = mse(tf.squeeze(train_model.value_fn), rewards_ph)
+        entropy = tf.reduce_mean(calc_entropy(train_model.policy))
+        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

         params = find_trainable_variables("model")
         grads = tf.gradients(loss, params)
         if max_grad_norm is not None:
-            grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
+            grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
         grads = list(zip(grads, params))
-        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
+        trainer = tf.train.RMSPropOptimizer(learning_rate=learning_rate_ph, decay=alpha, epsilon=epsilon)
         _train = trainer.apply_gradients(grads)

-        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
+        learning_rate = Scheduler(initial_value=learning_rate, n_values=total_timesteps, schedule=lr_schedule)

         def train(obs, states, rewards, masks, actions, values):
             advs = rewards - values
-            for step in range(len(obs)):
-                cur_lr = lr.value()
-            td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
+            for _ in range(len(obs)):
+                cur_lr = learning_rate.value()
+            td_map = {train_model.obs_ph: obs, actions_ph: actions, advs_ph: advs,
+                      rewards_ph: rewards, learning_rate_ph: cur_lr}
             if states is not None:
-                td_map[train_model.S] = states
-                td_map[train_model.M] = masks
+                td_map[train_model.states_ph] = states
+                td_map[train_model.masks_ph] = masks
             policy_loss, value_loss, policy_entropy, _ = sess.run(
                 [pg_loss, vf_loss, entropy, _train],
                 td_map
             )
             return policy_loss, value_loss, policy_entropy

         def save(save_path):
-            ps = sess.run(params)
-            make_path(osp.dirname(save_path))
-            joblib.dump(ps, save_path)
+            parameters = sess.run(params)
+            make_path(os.path.dirname(save_path))
+            joblib.dump(parameters, save_path)

         def load(load_path):
             loaded_params = joblib.load(load_path)
             restores = []
-            for p, loaded_p in zip(params, loaded_params):
-                restores.append(p.assign(loaded_p))
+            for param, loaded_p in zip(params, loaded_params):
+                restores.append(param.assign(loaded_p))
             sess.run(restores)

         self.train = train
@@ -82,16 +98,30 @@ def load(load_path):
         self.load = load
         tf.global_variables_initializer().run(session=sess)

-class Runner(AbstractEnvRunner):

-    def __init__(self, env, model, nsteps=5, gamma=0.99):
-        super().__init__(env=env, model=model, nsteps=nsteps)
+class Runner(AbstractEnvRunner):
+    def __init__(self, env, model, n_steps=5, gamma=0.99):
+        """
+        A runner to learn the policy of an environment for a model
+
+        :param env: (Gym environment) The environment to learn from
+        :param model: (Model) The model to learn
+        :param n_steps: (int) The number of steps to run for each environment
+        :param gamma: (float) Discount factor
+        """
+        super(Runner, self).__init__(env=env, model=model, n_steps=n_steps)
         self.gamma = gamma

     def run(self):
-        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
+        """
+        Run a learning step of the model
+
+        :return: ([float], [float], [float], [bool], [float], [float])
+                 observations, states, rewards, masks, actions, values
+        """
+        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
         mb_states = self.states
-        for n in range(self.nsteps):
+        for _ in range(self.n_steps):
             actions, values, states, _ = self.model.step(self.obs, self.states, self.dones)
             mb_obs.append(np.copy(self.obs))
             mb_actions.append(actions)
@@ -102,11 +132,11 @@ def run(self):
             self.dones = dones
             for n, done in enumerate(dones):
                 if done:
-                    self.obs[n] = self.obs[n]*0
+                    self.obs[n] = self.obs[n] * 0
             self.obs = obs
             mb_rewards.append(rewards)
         mb_dones.append(self.dones)
-        #batch of steps to batch of rollouts
+        # batch of steps to batch of rollouts
         mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape)
         mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
         mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
@@ -115,12 +145,12 @@
         mb_masks = mb_dones[:, :-1]
         mb_dones = mb_dones[:, 1:]
         last_values = self.model.value(self.obs, self.states, self.dones).tolist()
-        #discount/bootstrap off value fn
+        # discount/bootstrap off value fn
         for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
             rewards = rewards.tolist()
             dones = dones.tolist()
             if dones[-1] == 0:
-                rewards = discount_with_dones(rewards+[value], dones+[0], self.gamma)[:-1]
+                rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
             else:
                 rewards = discount_with_dones(rewards, dones, self.gamma)
             mb_rewards[n] = rewards
@@ -130,31 +160,56 @@ def run(self):
         mb_masks = mb_masks.flatten()
         return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values

-def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
+
+def learn(policy, env, seed, n_steps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5,
+          learning_rate=7e-4, lr_schedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
+    """
+    Return a trained A2C model.
+
+    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
+    :param env: (Gym environment) The environment to learn from
+    :param seed: (int) The initial seed for training
+    :param n_steps: (int) The number of steps to run for each environment
+    :param total_timesteps: (int) The total number of samples
+    :param vf_coef: (float) Value function coefficient for the loss calculation
+    :param ent_coef: (float) Entropy coefficient for the loss caculation
+    :param max_grad_norm: (float) The maximum value for the gradient clipping
+    :param learning_rate: (float) The learning rate
+    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
+                        'double_linear_con', 'middle_drop' or 'double_middle_drop')
+    :param epsilon: (float) RMS prop optimizer epsilon
+    :param alpha: (float) RMS prop optimizer decay
+    :param gamma: (float) Discount factor
+    :param log_interval: (int) The number of timesteps before logging.
+    :return: (Model) A2C model
+    """
     set_global_seeds(seed)

-    nenvs = env.num_envs
+    n_envs = env.num_envs
     ob_space = env.observation_space
     ac_space = env.action_space
-    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
-        max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
-    runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
-
-    nbatch = nenvs*nsteps
-    tstart = time.time()
-    for update in range(1, total_timesteps//nbatch+1):
+    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, n_envs=n_envs,
+                  n_steps=n_steps, ent_coef=ent_coef,
+                  vf_coef=vf_coef, max_grad_norm=max_grad_norm, learning_rate=learning_rate,
+                  alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps,
+                  lr_schedule=lr_schedule)
+    runner = Runner(env, model, n_steps=n_steps, gamma=gamma)
+
+    n_batch = n_envs * n_steps
+    t_start = time.time()
+    for update in range(1, total_timesteps // n_batch + 1):
         obs, states, rewards, masks, actions, values = runner.run()
-        policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
-        nseconds = time.time()-tstart
-        fps = int((update*nbatch)/nseconds)
+        _, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
+        n_seconds = time.time() - t_start
+        fps = int((update * n_batch) / n_seconds)
         if update % log_interval == 0 or update == 1:
-            ev = explained_variance(values, rewards)
+            explained_var = explained_variance(values, rewards)
             logger.record_tabular("nupdates", update)
-            logger.record_tabular("total_timesteps", update*nbatch)
+            logger.record_tabular("total_timesteps", update * n_batch)
             logger.record_tabular("fps", fps)
             logger.record_tabular("policy_entropy", float(policy_entropy))
             logger.record_tabular("value_loss", float(value_loss))
-            logger.record_tabular("explained_variance", float(ev))
+            logger.record_tabular("explained_variance", float(explained_var))
             logger.dump_tabular()
     env.close()
     return model
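
For orientation, below is a minimal, hypothetical sketch of how the refactored learn() signature above could be invoked. Only the learn() keyword names come from this diff; the CnnPolicy and DummyVecEnv import paths are assumptions carried over from upstream baselines, and the usual Atari preprocessing wrappers are omitted.

# Hypothetical usage sketch of the refactored A2C learn() API shown above.
# CnnPolicy and DummyVecEnv import paths are assumptions (upstream baselines
# layout); only the learn() keyword names come from this diff.
import gym

from baselines.a2c.a2c import learn
from baselines.a2c.policies import CnnPolicy                    # assumed path
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv  # assumed path

# A single-process vectorized environment; real Atari training would add the
# usual frame-skip/frame-stack wrappers, which this sketch leaves out.
env = DummyVecEnv([lambda: gym.make("BreakoutNoFrameskip-v4")])

model = learn(policy=CnnPolicy, env=env, seed=0,
              n_steps=5, learning_rate=7e-4, lr_schedule='linear',
              total_timesteps=int(1e6), log_interval=100)
model.save("a2c_breakout.pkl")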
