|
1 | 1 | import numpy as np |
2 | | -from tianshou.data import ReplayBuffer |
| 2 | +from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer |
3 | 3 |
|
4 | 4 | if __name__ == '__main__': |
5 | 5 | from env import MyTestEnv |
@@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4): |
47 | 47 | print(buf) |
48 | 48 |
|
49 | 49 |
|
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check the weight bookkeeping of ``PrioritizedReplayBuffer``.

    Transitions produced by ``MyTestEnv`` are added one at a time, each with
    a random trailing value (presumably the transition's priority -- confirm
    against ``PrioritizedReplayBuffer.add``).  After every insertion the test
    checks that the normalised weights of the stored entries sum to one, that
    ``sample`` returns the expected number of items, and that the buffer
    length never exceeds ``bufsize``.  Finally a sampled batch is re-weighted
    and ``_weight_sum`` is checked against the stored weights again.

    :param size: argument forwarded to ``MyTestEnv``.
    :param bufsize: first argument of ``PrioritizedReplayBuffer``; the test
        expects the buffer length to be capped at this value.
    """
    env = MyTestEnv(size)
    # NOTE(review): the two 0.5 arguments look like alpha and beta -- confirm
    # against the PrioritizedReplayBuffer signature.
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        # Normalised weights of the filled part of the buffer must sum to 1.
        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
                          1, rtol=1e-12)
        data, indice = buf.sample(len(buf) // 2)
        # A requested batch size of 0 is expected to yield the whole buffer.
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        # Use a real message: print() returns None, which would leave the
        # AssertionError empty.
        assert len(buf) == min(bufsize, i + 1), \
            'len(buf)={}, i={}'.format(len(buf), i)
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
    data, indice = buf.sample(len(buf) // 2)
    # Negative values on purpose: the check below expects the buffer to store
    # abs(new_weight) ** alpha.  Computed once so the assertion compares
    # against exactly what was passed to update_weight.
    new_weight = -data.weight / 2
    buf.update_weight(indice, new_weight)
    assert np.isclose(buf.weight[indice],
                      np.power(np.abs(new_weight), buf._alpha)).all()
    # The cached sum must stay in sync after the update.
    assert np.isclose(buf._weight_sum, (buf.weight).sum())
| 73 | + |
| 74 | + |
if __name__ == '__main__':
    # Executed as a script: call each buffer test directly.  The prioritized
    # buffer test is given a larger size and bufsize than its defaults.
    test_replaybuffer()
    test_stack()
    test_priortized_replaybuffer(size=233333, bufsize=200000)
0 commit comments