Skip to content

Commit 97df511

Browse files
authored
Add VizDoom PPO example and results (#533)
* update vizdoom ppo example
* update README with results
1 parent 23fbc3b commit 97df511

File tree

8 files changed

+102
-32
lines changed

8 files changed

+102
-32
lines changed

examples/vizdoom/README.md

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -53,14 +53,35 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
5353

5454
See [maps/README.md](maps/README.md)
5555

56-
## Algorithms
57-
58-
The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
59-
6056
## Reward
6157

6258
1. living reward is bad
6359
2. combo-action is really important
6460
3. negative reward for health and ammo2 is really helpful for d3/d4
6561
4. only with positive reward for health is really helpful for d1
6662
5. remove MOVE_BACKWARD may converge faster but the final performance may be lower
63+
64+
## Algorithms
65+
66+
The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
67+
68+
### C51 (single run)
69+
70+
| task | best reward | reward curve | parameters |
71+
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
72+
| D2_navigation | 747.52 | ![](results/c51/D2_navigation_rew.png) | `python3 vizdoom_c51.py --task "D2_navigation"` |
73+
| D3_battle | 1855.29 | ![](results/c51/D3_battle_rew.png) | `python3 vizdoom_c51.py --task "D3_battle"` |
74+
75+
### PPO (single run)
76+
77+
| task | best reward | reward curve | parameters |
78+
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
79+
| D2_navigation | 770.75 | ![](results/ppo/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation"` |
80+
| D3_battle | 320.59 | ![](results/ppo/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle"` |
81+
82+
### PPO with ICM (single run)
83+
84+
| task | best reward | reward curve | parameters |
85+
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
86+
| D2_navigation | 844.99 | ![](results/ppo_icm/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10` |
87+
| D3_battle | 547.08 | ![](results/ppo_icm/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10` |
209 KB
Loading
156 KB
Loading
159 KB
Loading
165 KB
Loading
157 KB
Loading
159 KB
Loading
Lines changed: 77 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,12 @@
66
import torch
77
from env import Env
88
from network import DQN
9+
from torch.optim.lr_scheduler import LambdaLR
910
from torch.utils.tensorboard import SummaryWriter
1011

1112
from tianshou.data import Collector, VectorReplayBuffer
1213
from tianshou.env import ShmemVectorEnv
13-
from tianshou.policy import A2CPolicy, ICMPolicy
14+
from tianshou.policy import ICMPolicy, PPOPolicy
1415
from tianshou.trainer import onpolicy_trainer
1516
from tianshou.utils import TensorboardLogger
1617
from tianshou.utils.net.common import ActorCritic
@@ -21,18 +22,28 @@ def get_args():
2122
parser = argparse.ArgumentParser()
2223
parser.add_argument('--task', type=str, default='D2_navigation')
2324
parser.add_argument('--seed', type=int, default=0)
24-
parser.add_argument('--buffer-size', type=int, default=2000000)
25-
parser.add_argument('--lr', type=float, default=0.0001)
25+
parser.add_argument('--buffer-size', type=int, default=100000)
26+
parser.add_argument('--lr', type=float, default=0.00002)
2627
parser.add_argument('--gamma', type=float, default=0.99)
2728
parser.add_argument('--epoch', type=int, default=300)
2829
parser.add_argument('--step-per-epoch', type=int, default=100000)
29-
parser.add_argument('--episode-per-collect', type=int, default=10)
30-
parser.add_argument('--update-per-step', type=float, default=0.1)
31-
parser.add_argument('--update-per-step', type=int, default=1)
32-
parser.add_argument('--batch-size', type=int, default=64)
33-
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
30+
parser.add_argument('--step-per-collect', type=int, default=1000)
31+
parser.add_argument('--repeat-per-collect', type=int, default=4)
32+
parser.add_argument('--batch-size', type=int, default=256)
33+
parser.add_argument('--hidden-size', type=int, default=512)
3434
parser.add_argument('--training-num', type=int, default=10)
3535
parser.add_argument('--test-num', type=int, default=100)
36+
parser.add_argument('--rew-norm', type=int, default=False)
37+
parser.add_argument('--vf-coef', type=float, default=0.5)
38+
parser.add_argument('--ent-coef', type=float, default=0.01)
39+
parser.add_argument('--gae-lambda', type=float, default=0.95)
40+
parser.add_argument('--lr-decay', type=int, default=True)
41+
parser.add_argument('--max-grad-norm', type=float, default=0.5)
42+
parser.add_argument('--eps-clip', type=float, default=0.2)
43+
parser.add_argument('--dual-clip', type=float, default=None)
44+
parser.add_argument('--value-clip', type=int, default=0)
45+
parser.add_argument('--norm-adv', type=int, default=1)
46+
parser.add_argument('--recompute-adv', type=int, default=0)
3647
parser.add_argument('--logdir', type=str, default='log')
3748
parser.add_argument('--render', type=float, default=0.)
3849
parser.add_argument(
@@ -75,7 +86,7 @@ def get_args():
7586
return parser.parse_args()
7687

7788

78-
def test_a2c(args=get_args()):
89+
def test_ppo(args=get_args()):
7990
args.cfg_path = f"maps/{args.task}.cfg"
8091
args.wad_path = f"maps/{args.task}.wad"
8192
args.res = (args.skip_num, 84, 84)
@@ -105,33 +116,65 @@ def test_a2c(args=get_args()):
105116
test_envs.seed(args.seed)
106117
# define model
107118
net = DQN(
108-
*args.state_shape, args.action_shape, device=args.device, features_only=True
119+
*args.state_shape,
120+
args.action_shape,
121+
device=args.device,
122+
features_only=True,
123+
output_dim=args.hidden_size
109124
)
110-
actor = Actor(
111-
net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
112-
)
113-
critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
125+
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
126+
critic = Critic(net, device=args.device)
114127
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
128+
129+
lr_scheduler = None
130+
if args.lr_decay:
131+
# decay learning rate to 0 linearly
132+
max_update_num = np.ceil(
133+
args.step_per_epoch / args.step_per_collect
134+
) * args.epoch
135+
136+
lr_scheduler = LambdaLR(
137+
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
138+
)
139+
115140
# define policy
116-
dist = torch.distributions.Categorical
117-
policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
141+
def dist(p):
142+
return torch.distributions.Categorical(logits=p)
143+
144+
policy = PPOPolicy(
145+
actor,
146+
critic,
147+
optim,
148+
dist,
149+
discount_factor=args.gamma,
150+
gae_lambda=args.gae_lambda,
151+
max_grad_norm=args.max_grad_norm,
152+
vf_coef=args.vf_coef,
153+
ent_coef=args.ent_coef,
154+
reward_normalization=args.rew_norm,
155+
action_scaling=False,
156+
lr_scheduler=lr_scheduler,
157+
action_space=env.action_space,
158+
eps_clip=args.eps_clip,
159+
value_clip=args.value_clip,
160+
dual_clip=args.dual_clip,
161+
advantage_normalization=args.norm_adv,
162+
recompute_advantage=args.recompute_adv
163+
).to(args.device)
118164
if args.icm_lr_scale > 0:
119165
feature_net = DQN(
120166
*args.state_shape,
121167
args.action_shape,
122168
device=args.device,
123-
features_only=True
169+
features_only=True,
170+
output_dim=args.hidden_size
124171
)
125172
action_dim = np.prod(args.action_shape)
126173
feature_dim = feature_net.output_dim
127174
icm_net = IntrinsicCuriosityModule(
128-
feature_net.net,
129-
feature_dim,
130-
action_dim,
131-
hidden_sizes=args.hidden_sizes,
132-
device=args.device
175+
feature_net.net, feature_dim, action_dim, device=args.device
133176
)
134-
icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr)
177+
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
135178
policy = ICMPolicy(
136179
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
137180
args.icm_forward_loss_weight
@@ -153,7 +196,8 @@ def test_a2c(args=get_args()):
153196
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
154197
test_collector = Collector(policy, test_envs, exploration_noise=True)
155198
# log
156-
log_path = os.path.join(args.logdir, args.task, 'a2c')
199+
log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
200+
log_path = os.path.join(args.logdir, args.task, log_name)
157201
writer = SummaryWriter(log_path)
158202
writer.add_text("args", str(args))
159203
logger = TensorboardLogger(writer)
@@ -162,10 +206,15 @@ def save_fn(policy):
162206
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
163207

164208
def stop_fn(mean_rewards):
165-
return False
209+
if env.spec.reward_threshold:
210+
return mean_rewards >= env.spec.reward_threshold
211+
elif 'Pong' in args.task:
212+
return mean_rewards >= 20
213+
else:
214+
return False
166215

216+
# watch agent's performance
167217
def watch():
168-
# watch agent's performance
169218
print("Setup test envs ...")
170219
policy.eval()
171220
test_envs.seed(args.seed)
@@ -210,7 +259,7 @@ def watch():
210259
args.repeat_per_collect,
211260
args.test_num,
212261
args.batch_size,
213-
episode_per_collect=args.episode_per_collect,
262+
step_per_collect=args.step_per_collect,
214263
stop_fn=stop_fn,
215264
save_fn=save_fn,
216265
logger=logger,
@@ -222,4 +271,4 @@ def watch():
222271

223272

224273
if __name__ == '__main__':
225-
test_a2c(get_args())
274+
test_ppo(get_args())

0 commit comments

Comments (0)