-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path train.py
More file actions
106 lines (77 loc) · 2.73 KB
/
train.py
File metadata and controls
106 lines (77 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Trains a policy on Atari Pong (ALE/Pong-v5), checkpointing it to model.pt after every epoch.
"""
import argparse
import gym
import numpy as np
import torch
from torchtyping import TensorType
from tqdm import tqdm
from policy import Policy
from preprocess import preprocess
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# The policy chooses between two options; collect_trajectory maps choice 0/1
# onto ALE action ids 2/3.
act_dim = 2
env = gym.make("ALE/Pong-v5", full_action_space=False)
# Feed a blank frame through `preprocess` to discover the flattened input size
# the policy network must accept.
obs_dim = preprocess(np.zeros(env.observation_space.shape)).numel()
def collect_trajectory(policy: Policy):
    """
    Play a single episode under `policy` and record every step.

    At each step the policy is shown the difference between the current and
    the previous preprocessed frame, moved to `device` and flattened. On the
    first step both frames come from the reset observation, so the difference
    is zero, assuming `preprocess` is deterministic. The policy's choice is
    translated to an environment action through the table (2, 3).

    A "Wins: ..., Losses: ..." line is printed when the episode ends. It counts
    steps with positive and with negative reward respectively.

    Returns:
        (observations, actions, rewards): three equal-length lists holding, per
        step, the (unflattened) frame-difference tensor given to the policy,
        the value returned by `policy.act`, and the reward that followed.
    """
    step_inputs, chosen_actions, step_rewards = [], [], []
    wins = losses = 0

    raw_frame, _ = env.reset()
    # Preprocess the reset frame twice: once as the "previous" frame and once
    # as the "current" one.
    previous_frame = preprocess(raw_frame)
    current_frame = preprocess(raw_frame)

    terminated = truncated = False
    while not (terminated or truncated):
        frame_delta = (current_frame - previous_frame).to(device=device)
        choice = policy.act(frame_delta.flatten())
        # Policy index -> ALE action id. NOTE(review): presumably up/down in
        # Pong's minimal action set; confirm.
        env_action = (2, 3)[choice]
        step_inputs.append(frame_delta)
        chosen_actions.append(choice)

        previous_frame = current_frame
        raw_frame, reward, terminated, truncated, _ = env.step(env_action)
        current_frame = preprocess(raw_frame)

        if reward > 0:
            wins += 1
        elif reward < 0:
            losses += 1
        step_rewards.append(reward)

    print(f"Wins: {wins}, Losses: {losses}")
    return step_inputs, chosen_actions, step_rewards
def collect_trajectories(policy: Policy, N: int = 10) -> tuple[
    list[list[torch.Tensor]],
    list[list],
    list[list],
]:
    """
    Collect `N` episodes from the environment under `policy`.

    Each episode is produced by `collect_trajectory` (a tqdm progress bar is
    shown while they run). The per-episode results are then regrouped by
    field.

    Args:
        policy: the policy used to choose actions.
        N: number of episodes to collect.

    Returns:
        (observations, actions, rewards): three Python lists, each with one
        entry per episode. Entry i is the list of per-step values from episode
        i: frame-difference tensors, values returned by `policy.act`, and
        rewards. Episodes can differ in length, so nothing is stacked into a
        single tensor.
    """
    trajectories = [collect_trajectory(policy) for _ in tqdm(range(N))]
    observations = [episode_obs for episode_obs, _, _ in trajectories]
    actions = [episode_acts for _, episode_acts, _ in trajectories]
    rewards = [episode_rews for _, _, episode_rews in trajectories]
    return observations, actions, rewards
def train(policy: Policy, epochs: int = 20):
    """
    Run `epochs` rounds of on-policy training.

    Each round collects a fresh batch of episodes with the current policy,
    hands that batch to `policy.train`, then checkpoints via
    `policy.save("model.pt")`.
    """
    for epoch_idx in range(epochs):
        print("Epoch", epoch_idx)
        policy.train(collect_trajectories(policy))
        policy.save("model.pt")
if __name__ == "__main__":
    # The emulator in `env` is created at import time, so close it on every
    # exit path: normal completion, exceptions, Ctrl-C, and argparse exits.
    try:
        parser = argparse.ArgumentParser(description="Train a Pong policy.")
        parser.add_argument(
            "--load",
            action="store_true",
            help="load weights from model.pt before training",
        )
        parser.add_argument(
            "--epochs",
            type=int,
            default=20,
            help="number of training epochs (default: 20)",
        )
        args = parser.parse_args()

        policy = Policy(obs_dim, act_dim)
        # NOTE(review): the model is moved to `device` before loading; this
        # relies on Policy.load restoring weights into the existing module —
        # confirm.
        policy.model.to(device)
        if args.load:
            print("Loading model...")
            policy.load("model.pt")
        print(f"Training for {args.epochs} epochs...")
        train(policy, args.epochs)
    finally:
        env.close()