Skip to content

Commit b237494

Browse files
authored
Prioritized DQN (#30)
* add sum_tree.py * add prioritized replay buffer * del sum_tree.py * fix some format issues * fix weight_update bug * simply replace replaybuffer in test_dqn without weight update * weight default set to 1 * fix sampling bug when buffer is not full * rename parameter * fix formula error, add accuracy check * add PrioritizedDQN test * add test_pdqn.py * add update_weight() doc * add ref of prio dqn in readme.md and index.rst * restore test_dqn.py, fix args of test_pdqn.py
1 parent 7029034 commit b237494

File tree

6 files changed

+263
-14
lines changed

6 files changed

+263
-14
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
2121
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
2222
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
23+
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
2324
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
2425
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
2526
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Welcome to Tianshou!
1111
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
1212
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
1313
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
14+
* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
1415
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
1516
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
1617
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
from tianshou.data import PrioritizedReplayBuffer
3+
4+
if __name__ == '__main__':
5+
from env import MyTestEnv
6+
else: # pytest
7+
from test.base.env import MyTestEnv
8+
9+
10+
def test_replaybuffer(size=32, bufsize=15):
    """Check the bookkeeping of ``PrioritizedReplayBuffer``.

    Transitions from ``MyTestEnv(size)`` are added one at a time with a
    random priority weight. After each insertion the test verifies that the
    normalised priorities of the filled slots sum to one, that ``sample``
    returns the requested number of items (or the whole buffer when asked
    for ``0``), and that the buffer length is capped at ``bufsize``.
    Finally it checks that ``update_weight`` stores ``|w| ** alpha`` and keeps
    the cached ``_weight_sum`` equal to the true sum.

    :param size: argument forwarded to ``MyTestEnv``.
    :param bufsize: capacity of the buffer under test.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        assert np.isclose(
            np.sum((buf.weight / buf._weight_sum)[:buf._size]), 1,
            rtol=1e-12)
        data, indice = buf.sample(len(buf) // 2)
        if len(buf) // 2 == 0:
            # sample(0) returns every stored transition
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        # A real message instead of ``print(...)``, which evaluates to None
        # and left the AssertionError without any detail.
        assert len(buf) == min(bufsize, i + 1), \
            f'len(buf)={len(buf)}, i={i}'
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    assert np.isclose(buf.weight[indice], np.power(
        np.abs(-data.weight / 2), buf._alpha)).all()
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
33+
34+
35+
if __name__ == "__main__":
36+
test_replaybuffer(233333, 200000)
37+
print("pass")

test/discrete/test_pdqn.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import os
2+
import gym
3+
import torch
4+
import pprint
5+
import argparse
6+
import numpy as np
7+
from torch.utils.tensorboard import SummaryWriter
8+
9+
from tianshou.env import VectorEnv
10+
from tianshou.policy import DQNPolicy
11+
from tianshou.trainer import offpolicy_trainer
12+
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
13+
14+
if __name__ == '__main__':
15+
from net import Net
16+
else: # pytest
17+
from test.discrete.net import Net
18+
19+
20+
def get_args():
    """Parse the command-line options of the prioritized DQN test.

    Every option has a default, so calling this with no arguments yields a
    complete configuration. Arguments the parser does not know are ignored
    (``parse_known_args``), which lets the module be imported under pytest.

    :return: the ``argparse.Namespace`` holding the recognised options.
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (flag, type, default) -- registered in exactly this order
    option_table = (
        ('--task', str, 'CartPole-v0'),
        ('--seed', int, 1626),
        ('--eps-test', float, 0.05),
        ('--eps-train', float, 0.1),
        ('--buffer-size', int, 20000),
        ('--lr', float, 1e-3),
        ('--gamma', float, 0.9),
        ('--n-step', int, 3),
        ('--target-update-freq', int, 320),
        ('--epoch', int, 10),
        ('--step-per-epoch', int, 1000),
        ('--collect-per-step', int, 10),
        ('--batch-size', int, 64),
        ('--layer-num', int, 3),
        ('--training-num', int, 8),
        ('--test-num', int, 100),
        ('--logdir', str, 'log'),
        ('--render', float, 0.),
        ('--prioritized-replay', int, 1),
        ('--alpha', float, 0.5),
        ('--beta', float, 0.5),
        ('--device', str, default_device),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        parser.add_argument(flag, type=value_type, default=fallback)
    recognised, _ignored = parser.parse_known_args()
    return recognised
48+
49+
50+
def test_pdqn(args=get_args()):
    """Train a DQN agent with a prioritized replay buffer and check it.

    Builds vectorised training/test environments for ``args.task``, seeds
    numpy, torch and the environments, constructs the network, optimiser and
    ``DQNPolicy``, then runs ``offpolicy_trainer``. The test passes when the
    best reward reported by the trainer reaches the environment's
    ``reward_threshold``. When run as a script it additionally prints the
    training result and rolls out one episode.

    :param args: namespace as produced by ``get_args()``; ``state_shape`` and
        ``action_shape`` are added to it in place.
    """
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        use_target_network=args.target_update_freq > 0,
        target_update_freq=args.target_update_freq)
    # collector
    if args.prioritized_replay > 0:
        # alpha and beta are independent options; pass each its own value
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    train_collector = Collector(
        policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    # pre-fill the buffer with one batch worth of steps before training
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        # persist the policy weights under the log directory
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        # x is compared against the task's reward threshold
        return x >= env.spec.reward_threshold

    def train_fn(x):
        # exploration rate used while collecting training data (x is unused)
        policy.set_eps(args.eps_train)

    def test_fn(x):
        # exploration rate used during evaluation (x is unused)
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()
119+
120+
121+
if __name__ == '__main__':
122+
test_pdqn(get_args())

tianshou/data/buffer.py

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,87 @@ def reset(self):
258258
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples transitions according to a priority.

    Every stored transition carries a priority ``|weight| ** alpha`` kept in
    ``self.weight``. ``sample`` draws indices with probability proportional
    to that priority and can attach importance-sampling weights to the
    returned batch. A running total of the priorities is cached in
    ``_weight_sum`` and recomputed exactly every ``_amortization_freq``
    calls to ``add``/``sample`` to limit floating-point drift.

    :param size: buffer capacity, forwarded to ``ReplayBuffer``; also the
        length of the priority array.
    :param alpha: prioritization exponent.
    :param beta: importance sample soft coefficient.
    :param mode: only ``'weight'`` is implemented.
    :raises NotImplementedError: if ``mode`` is not ``'weight'``.
    """

    def __init__(self, size, alpha: float, beta: float,
                 mode: str = 'weight', **kwargs):
        if mode != 'weight':
            raise NotImplementedError
        super().__init__(size, **kwargs)
        self._alpha = alpha  # prioritization exponent
        self._beta = beta  # importance sample soft coefficient
        self._weight_sum = 0.0
        self.weight = np.zeros(size, dtype=np.float64)
        self._amortization_freq = 50
        self._amortization_counter = 0

    def add(self, obs, act, rew, done, obs_next=0, info=None, weight=1.0):
        """Add one transition together with its priority weight.

        :param weight: raw priority; ``|weight| ** alpha`` is what is stored.
        :param info: extra information; a fresh empty dict is used when
            omitted, so no dict is shared between calls.
        """
        if info is None:
            info = {}
        priority = np.abs(weight) ** self._alpha
        # replace the contribution of the slot that is about to be written
        self._weight_sum += priority - self.weight[self._index]
        # we have to sacrifice some convenience for speed :(
        self._add_to_buffer('weight', priority)
        super().add(obs, act, rew, done, obs_next, info)
        self._check_weight_sum()

    def sample(self, batch_size: int = 0, importance_sample: bool = True):
        """Get a random sample from the buffer with priority probability.

        Return all the data in the buffer if ``batch_size`` is ``0``.

        :param batch_size: number of transitions to draw; must lie between
            ``0`` and the current number of stored transitions.
        :param importance_sample: when true, the importance weights
            ``1 / (size * weight / weight_sum) ** beta`` are appended to the
            batch under the key ``impt_weight``.
        :return: Sample data and its corresponding index inside the buffer.
        :raises ValueError: if ``batch_size`` is negative or larger than the
            number of stored transitions.
        """
        if batch_size > 0 and batch_size <= self._size:
            # Multiple sampling of the same sample
            # will cause weight update conflict
            indice = np.random.choice(
                self._size, batch_size,
                p=(self.weight / self.weight.sum())[:self._size],
                replace=False)
            # The cached self._weight_sum is not accurate enough to be used
            # as the normaliser here, hence the exact sum above.
        elif batch_size == 0:
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        else:
            # if batch_size larger than len(self),
            # it will lead to a bug in update weight
            raise ValueError("batch_size should be less than len(self)")
        batch = self[indice]
        if importance_sample:
            impt_weight = Batch(
                impt_weight=1 / np.power(
                    self._size * (batch.weight / self._weight_sum),
                    self._beta))
            batch.append(impt_weight)
        self._check_weight_sum()
        return batch, indice

    def reset(self):
        """Reset the amortization counter, then reset the parent buffer."""
        # NOTE(review): self.weight and self._weight_sum are not cleared
        # here -- confirm whether super().reset() empties the buffer, in
        # which case stale priorities would remain in unused slots.
        self._amortization_counter = 0
        super().reset()

    def update_weight(self, indice, new_weight: np.ndarray):
        """Update the priority weight of the given indices in this buffer.

        :param indice: indices whose priority should be updated.
        :param new_weight: new raw priority weights; ``|w| ** alpha`` is
            stored.
        """
        priority = np.power(np.abs(new_weight), self._alpha)
        self._weight_sum += priority.sum() - self.weight[indice].sum()
        self.weight[indice] = priority

    def __getitem__(self, index):
        """Return the data at ``index`` as a ``Batch`` including ``weight``."""
        return Batch(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.info[index],
            weight=self.weight[index]
        )

    def _check_weight_sum(self):
        """Recompute ``_weight_sum`` exactly every ``_amortization_freq`` calls."""
        # keep an accurate _weight_sum
        self._amortization_counter += 1
        if self._amortization_counter % self._amortization_freq == 0:
            self._weight_sum = np.sum(self.weight)
            self._amortization_counter = 0

tianshou/policy/modelfree/dqn.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from copy import deepcopy
44
import torch.nn.functional as F
55

6-
from tianshou.data import Batch
6+
from tianshou.data import Batch, PrioritizedReplayBuffer
77
from tianshou.policy import BasePolicy
88

99

@@ -98,6 +98,18 @@ def process_fn(self, batch, buffer, indice):
9898
target_q[gammas != self._n_step] = 0
9999
returns += (self._gamma ** gammas) * target_q
100100
batch.returns = returns
101+
if isinstance(buffer, PrioritizedReplayBuffer):
102+
q = self(batch).logits
103+
q = q[np.arange(len(q)), batch.act]
104+
r = batch.returns
105+
if isinstance(r, np.ndarray):
106+
r = torch.tensor(r, device=q.device, dtype=q.dtype)
107+
td = r-q
108+
buffer.update_weight(indice, td.detach().numpy())
109+
impt_weight = torch.tensor(batch.impt_weight,
110+
device=q.device, dtype=torch.float)
111+
loss = (td.pow(2)*impt_weight).mean()
112+
batch.loss = loss
101113
return batch
102114

103115
def forward(self, batch, state=None,
@@ -133,12 +145,15 @@ def learn(self, batch, **kwargs):
133145
if self._target and self._cnt % self._freq == 0:
134146
self.sync_weight()
135147
self.optim.zero_grad()
136-
q = self(batch).logits
137-
q = q[np.arange(len(q)), batch.act]
138-
r = batch.returns
139-
if isinstance(r, np.ndarray):
140-
r = torch.tensor(r, device=q.device, dtype=q.dtype)
141-
loss = F.mse_loss(q, r)
148+
if hasattr(batch, 'loss'):
149+
loss = batch.loss
150+
else:
151+
q = self(batch).logits
152+
q = q[np.arange(len(q)), batch.act]
153+
r = batch.returns
154+
if isinstance(r, np.ndarray):
155+
r = torch.tensor(r, device=q.device, dtype=q.dtype)
156+
loss = F.mse_loss(q, r)
142157
loss.backward()
143158
self.optim.step()
144159
self._cnt += 1

0 commit comments

Comments
 (0)