Skip to content

Commit 6b96f12

Browse files
committed
fix pdqn
1 parent b237494 commit 6b96f12

File tree

8 files changed

+45
-53
lines changed

8 files changed

+45
-53
lines changed

.github/ISSUE_TEMPLATE.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
+ [ ] documentation request (i.e. "X is missing from the documentation.")
55
+ [ ] new feature request
66
- [ ] I have visited the [source website], and in particular read the [known issues]
7-
- [ ] I have searched through the [issue tracker] for duplicates
7+
- [ ] I have searched through the [issue categories] for duplicates
88
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
99
```python
10-
import tianshou, sys
11-
print(tianshou.__version__, sys.version, sys.platform)
10+
import tianshou, torch, sys
11+
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
1212
```
1313

1414
[source website]: https://github.com/thu-ml/tianshou/
1515
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
16-
[issue tracker]: https://github.com/thu-ml/tianshou/projects/2
16+
[issue categories]: https://github.com/thu-ml/tianshou/projects/2

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
Less important but also useful:
99

1010
- [ ] I have visited the [source website], and in particular read the [known issues]
11-
- [ ] I have searched through the [issue tracker] for duplicates
11+
- [ ] I have searched through the [issue categories] for duplicates
1212
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
1313
```python
14-
import tianshou, sys
15-
print(tianshou.__version__, sys.version, sys.platform)
14+
import tianshou, torch, sys
15+
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
1616
```
1717

1818
[source website]: https://github.com/thu-ml/tianshou
1919
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
20-
[issue tracker]: https://github.com/thu-ml/tianshou/projects/2
20+
[issue categories]: https://github.com/thu-ml/tianshou/projects/2

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
2121
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
2222
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
23-
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf))
23+
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
2424
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
2525
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
2626
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)

docs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Welcome to Tianshou!
1111
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
1212
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
1313
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
14-
* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf`_
14+
* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
1515
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
1616
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
1717
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_

test/base/test_buffer.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from tianshou.data import ReplayBuffer
2+
from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer
33

44
if __name__ == '__main__':
55
from env import MyTestEnv
@@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4):
4747
print(buf)
4848

4949

50+
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Exercise ``PrioritizedReplayBuffer`` while filling it and after a
    weight update.

    The test steps a ``MyTestEnv`` through a fixed action sequence, adding
    each transition with a random weight, and checks after every step that:

    * the stored weights, normalized by ``buf._weight_sum``, sum to 1;
    * ``buf.sample`` returns the requested number of items (or ``len(buf)``
      items when the requested size is 0);
    * ``len(buf)`` never exceeds ``bufsize``.

    Afterwards it verifies that ``buf._weight_sum`` matches the actual sum of
    ``buf.weight`` both before and after ``buf.update_weight``.

    :param int size: argument forwarded to ``MyTestEnv``.
    :param int bufsize: capacity passed to ``PrioritizedReplayBuffer``.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        # Normalized weights of the filled part of the buffer must sum to 1.
        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
                          1, rtol=1e-12)
        data, indice = buf.sample(len(buf) // 2)
        if len(buf) // 2 == 0:
            # A requested batch size of 0 is expected to yield len(buf) items.
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        # Use a real message string: print() would return None, leaving the
        # AssertionError without any diagnostic text.
        assert len(buf) == min(bufsize, i + 1), \
            'len(buf)={}, i={}'.format(len(buf), i)
    # The cached sum must agree with the actual weights ...
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    # ... updated entries must equal |new_weight| ** alpha ...
    assert np.isclose(buf.weight[indice], np.power(
        np.abs(-data.weight / 2), buf._alpha)).all()
    # ... and the cached sum must still agree after the update.
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
73+
74+
5075
if __name__ == '__main__':
5176
test_replaybuffer()
5277
test_stack()
78+
test_priortized_replaybuffer(233333, 200000)

test/base/test_prioritized_replay_buffer.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tianshou/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from tianshou import data, env, utils, policy, trainer, \
22
exploration
33

4-
__version__ = '0.2.1'
4+
__version__ = '0.2.2'
55
__all__ = [
66
'env',
77
'data',

tianshou/policy/modelfree/dqn.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,15 @@ def process_fn(self, batch, buffer, indice):
104104
r = batch.returns
105105
if isinstance(r, np.ndarray):
106106
r = torch.tensor(r, device=q.device, dtype=q.dtype)
107-
td = r-q
108-
buffer.update_weight(indice, td.detach().numpy())
107+
td = r - q
108+
buffer.update_weight(indice, td.detach().cpu().numpy())
109109
impt_weight = torch.tensor(batch.impt_weight,
110110
device=q.device, dtype=torch.float)
111-
loss = (td.pow(2)*impt_weight).mean()
112-
batch.loss = loss
111+
loss = (td.pow(2) * impt_weight).mean()
112+
if not hasattr(batch, 'loss'):
113+
batch.loss = loss
114+
else:
115+
batch.loss += loss
113116
return batch
114117

115118
def forward(self, batch, state=None,

0 commit comments

Comments
 (0)