import os
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteBCQPolicy
from tianshou.data import Collector, ReplayBuffer

from atari_network import DQN
from atari_wrapper import wrap_deepmind

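# Offline discrete BCQ on Atari: the policy is trained purely from a
# pre-collected expert replay buffer, with no environment interaction during
# training. Hypothetical invocation (assuming the expert buffer produced by
# atari_dqn.py exists at the default path):
#   python atari_bcq.py --load-buffer-name ./expert_DQN_PongNoFrameskip-v4.hdf5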
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=6.25e-5)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=8000)
    parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
    parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--step-per-epoch", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument("--frames-stack", type=int, default=4)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--watch", default=False, action="store_true",
                        help="only watch the play of the pre-trained policy")
    parser.add_argument("--log-interval", type=int, default=1000)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_DQN_PongNoFrameskip-v4.hdf5",
    )
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_known_args()[0]
    return args

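# Standard DeepMind Atari preprocessing; the watch/test variant disables
# episode-life termination and reward clipping so that reported returns are
# true game scores rather than clipped training rewards.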
def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_discrete_bcq(args=get_args()):
    # envs
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
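    # Both actor heads below share one CNN torso (feature_net), so the
    # parameters are de-duplicated with a set union before being handed to
    # the optimizer (the shared torso would otherwise appear twice).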
    # model
    feature_net = DQN(*args.state_shape, args.action_shape,
                      device=args.device, features_only=True).to(args.device)
    policy_net = Actor(feature_net, args.action_shape,
                       hidden_sizes=args.hidden_sizes).to(args.device)
    imitation_net = Actor(feature_net, args.action_shape,
                          hidden_sizes=args.hidden_sizes).to(args.device)
    optim = torch.optim.Adam(
        set(policy_net.parameters()).union(imitation_net.parameters()),
        lr=args.lr,
    )
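    # Discrete BCQ (Fujimoto et al., 2019) restricts the Q-value argmax to
    # actions the imitation head deems sufficiently likely: roughly, actions
    # whose imitation probability falls below unlikely_action_threshold
    # relative to the most likely action are masked out of action selection.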
    # define policy
    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,
        args.target_update_freq, args.eps_test,
        args.unlikely_action_threshold, args.imitation_logits_penalty,
    )
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device
        ))
        print("Loaded agent from:", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_dqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith(".pkl"):
        # use a context manager so the file handle is closed promptly
        with open(args.load_buffer_name, "rb") as f:
            buffer = pickle.load(f)
    elif args.load_buffer_name.endswith(".hdf5"):
        buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        # an unrecognized format is an error, so exit with a non-zero status
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(1)

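    # Offline RL: the expert buffer is the only data source; there is no
    # train_collector, and the policy never steps the environment in training.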
    # collector
    test_collector = Collector(policy, test_envs)

    log_path = os.path.join(args.logdir, args.task, "discrete_bcq")
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        # never stop early; always train for the full epoch budget
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

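    # Offline training loop: each epoch runs step_per_epoch gradient updates
    # on batches sampled from the buffer, then evaluates on test_num episodes.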
    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
        log_interval=args.log_interval,
    )

    pprint.pprint(result)
    watch()

if __name__ == "__main__":
    test_discrete_bcq(get_args())