Commit a511cb4

Add offline trainer and discrete BCQ algorithm (#263)
The result needs to be tuned after the `done` issue is fixed. Co-authored-by: n+e <trinkle23897@gmail.com>
1 parent a633a6a commit a511cb4

26 files changed: +628 −80 lines

.gitignore (2 additions, 0 deletions)

@@ -145,3 +145,5 @@ MUJOCO_LOG.TXT
 *.zip
 *.pstats
 *.swp
+*.pkl
+*.hdf5

README.md (1 addition, 0 deletions)

@@ -31,6 +31,7 @@
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
 - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
 - Vanilla Imitation Learning
+- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)

docs/index.rst (2 additions, 1 deletion)

@@ -20,8 +20,9 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
 * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
-* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
+* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
+* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

docs/tutorials/concepts.rst (1 addition, 1 deletion)

@@ -201,7 +201,7 @@ Trainer

 Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.

-Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.
+Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

 .. _pseudocode:
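The defining property of the new `offline_trainer` is that it never collects fresh environment transitions; every gradient step samples minibatches from a fixed buffer. A minimal sketch of that loop shape in plain Python (the `policy_update` and `evaluate` callables are illustrative stand-ins, not Tianshou's actual signature):

```python
import random

def offline_trainer(policy_update, buffer, epochs, steps_per_epoch, batch_size,
                    stop_fn=None, evaluate=None):
    """Sketch of an offline training loop: sample only from a fixed buffer."""
    result = {"gradient_step": 0}
    for epoch in range(1, epochs + 1):
        for _ in range(steps_per_epoch):
            # no env interaction here, only sampling from the static dataset
            batch = random.sample(buffer, min(batch_size, len(buffer)))
            policy_update(batch)                 # one gradient step
            result["gradient_step"] += 1
        if evaluate is not None:
            rew = evaluate()                     # periodic test rollout
            if stop_fn is not None and stop_fn(rew):
                return result
    return result

# toy usage: the "update" just records the rewards seen in sampled batches
buffer = [{"obs": i, "rew": float(i % 2)} for i in range(100)]
seen = []
res = offline_trainer(lambda b: seen.extend(x["rew"] for x in b),
                      buffer, epochs=2, steps_per_epoch=5, batch_size=8)
print(res["gradient_step"])  # prints 10
```

Tianshou's real `offline_trainer` additionally handles saving, logging via `writer`, and best-policy tracking, but the control flow follows this pattern.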

docs/tutorials/dqn.rst (2 additions, 2 deletions)

@@ -120,7 +120,7 @@ In each step, the collector will let the policy perform (at least) a specified n
 Train Policy with a Trainer
 ---------------------------

-Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
+Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`. The trainer will automatically stop training when the policy reaches the stop condition ``stop_fn`` on the test collector. Since DQN is an off-policy algorithm, we use the :func:`~tianshou.trainer.offpolicy_trainer` as follows:
 ::

     result = ts.trainer.offpolicy_trainer(
@@ -133,7 +133,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
         writer=None)
     print(f'Finished training! Use {result["duration"]}')

-The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
+The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):

 * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
 * ``step_per_epoch``: The number of step for updating policy network in one epoch;

examples/atari/README.md (12 additions, 1 deletion)

@@ -38,4 +38,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 | SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
 | SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |

-Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
+Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
+
+# BCQ
+
+TODO: after the `done` issue is fixed, the result should be re-tuned and placed here.
+
+To run the BCQ algorithm on Atari, you need to do the following:
+
+- Train an expert, by using the command listed in the above DQN section;
+- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
+- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
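The BCQ-specific flag in the example script is `--unlikely-action-threshold`: at action-selection time, discrete BCQ only considers actions whose imitation probability is within a relative threshold of the most likely action, then picks the Q-greedy action among those candidates. A rough sketch of that rule in plain Python (variable names are illustrative, not Tianshou's internals):

```python
def bcq_select_action(q_values, imitation_probs, tau=0.3):
    """Pick argmax-Q among actions the imitation model deems likely enough.

    An action a is a candidate iff p(a) / max_a' p(a') > tau; all other
    actions are masked out of the Q-greedy argmax (Fujimoto et al., 2019).
    """
    max_prob = max(imitation_probs)
    candidates = [a for a, p in enumerate(imitation_probs)
                  if p / max_prob > tau]
    return max(candidates, key=lambda a: q_values[a])

# action 0 has the highest Q but is almost never taken by the expert,
# so BCQ falls back to the best Q among well-supported actions
q = [5.0, 1.0, 2.0]
probs = [0.02, 0.50, 0.48]
print(bcq_select_action(q, probs, tau=0.3))  # -> 2
```

This masking is what keeps the learned policy close to the behavior policy that generated the buffer, avoiding Q-value extrapolation on actions the dataset never supports.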

examples/atari/atari_bcq.py (new file: 153 additions, 0 deletions)

@@ -0,0 +1,153 @@
+import os
+import torch
+import pickle
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.env import SubprocVectorEnv
+from tianshou.trainer import offline_trainer
+from tianshou.utils.net.discrete import Actor
+from tianshou.policy import DiscreteBCQPolicy
+from tianshou.data import Collector, ReplayBuffer
+
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
+    parser.add_argument("--seed", type=int, default=1626)
+    parser.add_argument("--eps-test", type=float, default=0.001)
+    parser.add_argument("--lr", type=float, default=6.25e-5)
+    parser.add_argument("--gamma", type=float, default=0.99)
+    parser.add_argument("--n-step", type=int, default=3)
+    parser.add_argument("--target-update-freq", type=int, default=8000)
+    parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
+    parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
+    parser.add_argument("--epoch", type=int, default=100)
+    parser.add_argument("--step-per-epoch", type=int, default=10000)
+    parser.add_argument("--batch-size", type=int, default=32)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[512])
+    parser.add_argument("--test-num", type=int, default=100)
+    parser.add_argument('--frames_stack', type=int, default=4)
+    parser.add_argument("--logdir", type=str, default="log")
+    parser.add_argument("--render", type=float, default=0.)
+    parser.add_argument("--resume-path", type=str, default=None)
+    parser.add_argument("--watch", default=False, action="store_true",
+                        help="watch the play of pre-trained policy only")
+    parser.add_argument("--log-interval", type=int, default=1000)
+    parser.add_argument(
+        "--load-buffer-name", type=str,
+        default="./expert_DQN_PongNoFrameskip-v4.hdf5",
+    )
+    parser.add_argument(
+        "--device", type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def make_atari_env(args):
+    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
+
+
+def make_atari_env_watch(args):
+    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
+                         episode_life=False, clip_rewards=False)
+
+
+def test_discrete_bcq(args=get_args()):
+    # envs
+    env = make_atari_env(args)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # should be N_FRAMES x H x W
+    print("Observations shape:", args.state_shape)
+    print("Actions shape:", args.action_shape)
+    # make environments
+    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
+                                  for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    feature_net = DQN(*args.state_shape, args.action_shape,
+                      device=args.device, features_only=True).to(args.device)
+    policy_net = Actor(feature_net, args.action_shape,
+                       hidden_sizes=args.hidden_sizes).to(args.device)
+    imitation_net = Actor(feature_net, args.action_shape,
+                          hidden_sizes=args.hidden_sizes).to(args.device)
+    optim = torch.optim.Adam(
+        set(policy_net.parameters()).union(imitation_net.parameters()),
+        lr=args.lr,
+    )
+    # define policy
+    policy = DiscreteBCQPolicy(
+        policy_net, imitation_net, optim, args.gamma, args.n_step,
+        args.target_update_freq, args.eps_test,
+        args.unlikely_action_threshold, args.imitation_logits_penalty,
+    )
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(
+            args.resume_path, map_location=args.device
+        ))
+        print("Loaded agent from: ", args.resume_path)
+    # buffer
+    assert os.path.exists(args.load_buffer_name), \
+        "Please run atari_dqn.py first to get expert's data buffer."
+    if args.load_buffer_name.endswith('.pkl'):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    elif args.load_buffer_name.endswith('.hdf5'):
+        buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
+    else:
+        print(f"Unknown buffer format: {args.load_buffer_name}")
+        exit(0)
+
+    # collector
+    test_collector = Collector(policy, test_envs)
+
+    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return False
+
+    # watch agent's performance
+    def watch():
+        print("Setup test envs ...")
+        policy.eval()
+        policy.set_eps(args.eps_test)
+        test_envs.seed(args.seed)
+        print("Testing agent ...")
+        test_collector.reset()
+        result = test_collector.collect(n_episode=[1] * args.test_num,
+                                        render=args.render)
+        pprint.pprint(result)
+
+    if args.watch:
+        watch()
+        exit(0)
+
+    result = offline_trainer(
+        policy, buffer, test_collector,
+        args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        log_interval=args.log_interval,
+    )
+
+    pprint.pprint(result)
+    watch()
+
+
+if __name__ == "__main__":
+    test_discrete_bcq(get_args())
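The update behind `DiscreteBCQPolicy` combines three terms: a TD loss on the Q-network, a cross-entropy imitation loss against the dataset action, and a small penalty on the imitation logits (weighted by `--imitation-logits-penalty`). A hedged sketch of how the scalar loss is assembled, on toy numbers in plain Python (not the actual Tianshou implementation):

```python
import math

def discrete_bcq_loss(td_errors, imitation_logits, expert_action, penalty=0.01):
    """Sketch: Q-learning loss + imitation cross-entropy + logits penalty."""
    # TD loss: mean squared temporal-difference error of the Q-network
    q_loss = sum(e * e for e in td_errors) / len(td_errors)
    # imitation loss: cross-entropy of the imitation head vs the dataset action
    log_z = math.log(sum(math.exp(l) for l in imitation_logits))
    i_loss = log_z - imitation_logits[expert_action]
    # regularization keeps the imitation logits from growing unboundedly
    reg = sum(l * l for l in imitation_logits) / len(imitation_logits)
    return q_loss + i_loss + penalty * reg

loss = discrete_bcq_loss([0.5, -0.3], [2.0, 0.1, -1.0], expert_action=0)
print(round(loss, 3))
```

With the small default penalty (0.01 above, matching the script's default), the regularizer is a gentle nudge rather than a dominant term.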

examples/atari/atari_dqn.py (17 additions, 4 deletions)

@@ -41,6 +41,7 @@ def get_args():
     parser.add_argument('--resume-path', type=str, default=None)
     parser.add_argument('--watch', default=False, action='store_true',
                         help='watch the play of pre-trained policy only')
+    parser.add_argument('--save-buffer-name', type=str, default=None)
     return parser.parse_args()


@@ -120,13 +120,25 @@ def test_fn(epoch, env_step):

     # watch agent's performance
     def watch():
-        print("Testing agent ...")
+        print("Setup test envs ...")
         policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
-        test_collector.reset()
-        result = test_collector.collect(n_episode=[1] * args.test_num,
-                                        render=args.render)
+        if args.save_buffer_name:
+            print(f"Generate buffer with size {args.buffer_size}")
+            buffer = ReplayBuffer(
+                args.buffer_size, ignore_obs_next=True,
+                save_only_last_obs=True, stack_num=args.frames_stack)
+            collector = Collector(policy, test_envs, buffer)
+            result = collector.collect(n_step=args.buffer_size)
+            print(f"Save buffer into {args.save_buffer_name}")
+            # Unfortunately, pickle will cause oom with 1M buffer size
+            buffer.save_hdf5(args.save_buffer_name)
+        else:
+            print("Testing agent ...")
+            test_collector.reset()
+            result = test_collector.collect(n_episode=[1] * args.test_num,
+                                            render=args.render)
         pprint.pprint(result)

     if args.watch:
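The `--save-buffer-name` branch above records transitions while the trained DQN acts with test-time exploration noise (`--eps-test 0.2` in the README recipe), so the offline dataset covers more than the expert's exact trajectory. A toy sketch of the epsilon-greedy rule that injects that noise (plain Python, not the Tianshou policy class):

```python
import random

def epsilon_greedy(q_values, eps, rng=random):
    """With probability eps take a uniform random action, else the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

# with eps=0.2 the greedy action is chosen ~ 0.8 + 0.2/3 of the time here
rng = random.Random(0)
actions = [epsilon_greedy([0.1, 0.9, 0.3], eps=0.2, rng=rng)
           for _ in range(1000)]
greedy_frac = actions.count(1) / len(actions)
print(round(greedy_frac, 2))
```

The resulting buffer is then what `atari_bcq.py` consumes via `--load-buffer-name`.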

examples/atari/atari_network.py (1 addition, 2 deletions)

@@ -44,8 +44,7 @@ def forward(
         info: Dict[str, Any] = {},
     ) -> Tuple[torch.Tensor, Any]:
         r"""Mapping: x -> Q(x, \*)."""
-        x = torch.as_tensor(
-            x, device=self.device, dtype=torch.float32)  # type: ignore
+        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
         return self.net(x), state

setup.py (1 addition, 1 deletion)

@@ -51,7 +51,7 @@ def get_version() -> str:
         "tensorboard",
         "torch>=1.4.0",
         "numba>=0.51.0",
-        "h5py>=3.1.0"
+        "h5py>=2.10.0",  # to match tensorflow's minimal requirements
     ],
     extras_require={
         "dev": [
