-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpolicy.py
More file actions
100 lines (78 loc) · 2.88 KB
/
policy.py
File metadata and controls
100 lines (78 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Vanilla Policy Gradient
"""
import torch
from torchtyping import TensorType
def reward_to_go(rewards: list[float], discount_factor=0.99) -> torch.Tensor:
    """Back-fill zero rewards with the discounted value of the following step.

    Walks the trajectory backwards. Every step whose reward is exactly zero
    takes ``discount_factor`` times the (already back-filled) value of the next
    step; non-zero rewards are kept unchanged. Each non-zero reward therefore
    acts as a reset point, so this is NOT a cumulative sum of future rewards.
    It is presumably intended for sparse-reward games where every non-zero
    reward ends a segment.

    Args:
        rewards: Per-step rewards of one trajectory. Any sequence of numbers
            (including a 1-D tensor) is accepted; it is never modified.
        discount_factor: Multiplier applied per step when propagating a
            reward backwards in time.

    Returns:
        A 1-D float32 tensor with one entry per step (empty for empty input).
    """
    # Copy into plain Python floats for two reasons. First, slicing a tensor
    # returns a view, so the loop would otherwise mutate the caller's data.
    # Second, integer-only rewards would otherwise produce an int64 tensor,
    # which mean()/std() later reject.
    out = [float(r) for r in rewards]
    for j in range(len(out) - 2, -1, -1):
        if out[j] == 0:
            out[j] = out[j + 1] * discount_factor
    return torch.tensor(out, dtype=torch.float32)
def baseline(trajectory_rewards: TensorType["N", "H"]):
    """Return the value subtracted from the reward-to-go before normalization.

    This is currently a placeholder. It ignores ``trajectory_rewards`` and
    always returns the integer ``0``, so no baseline is applied.
    """
    # TODO: Add baseline
    no_baseline = 0
    return no_baseline
def compute_advantage(trajectory_rewards: list[float], discount_factor=0.99) -> torch.Tensor:
    """Compute normalized per-step advantages for a single trajectory.

    The raw advantage is ``reward_to_go(trajectory_rewards)`` minus
    ``baseline(trajectory_rewards)``. It is then standardized to zero mean and
    unit (sample) standard deviation within this trajectory.

    Args:
        trajectory_rewards: Per-step rewards of one trajectory.
        discount_factor: Forwarded to ``reward_to_go``.

    Returns:
        A 1-D tensor with one advantage per step. Trajectories with fewer than
        two steps yield all zeros, because their sample standard deviation is
        undefined; ``std()`` would be NaN and poison the loss.
    """
    out = reward_to_go(trajectory_rewards, discount_factor) - baseline(trajectory_rewards)
    if out.numel() < 2:
        # A single value standardizes to 0, and an empty input stays empty.
        return torch.zeros_like(out)
    # Normalize rewards. The epsilon guards against a zero std (constant rewards).
    return (out - out.mean()) / (out.std() + 1e-8)
class Model(torch.nn.Module):
    """Small MLP that turns a flat observation vector into action logits.

    Architecture: ``Linear(obs_dim, 256) -> ReLU -> Linear(256, act_dim)``.
    The output is unnormalized logits; no softmax is applied here.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        # The attribute names become the state_dict keys, so renaming them
        # would break previously saved checkpoints.
        self.fc1 = torch.nn.Linear(obs_dim, 256)
        self.fc2 = torch.nn.Linear(256, act_dim)

    def forward(self, x):
        """Return action logits for the observation(s) ``x``."""
        return self.fc2(torch.relu(self.fc1(x)))
class Policy:
    """Vanilla policy-gradient agent with a categorical action distribution.

    Wraps a ``Model`` that maps observations to action logits, plus an Adam
    optimizer over that model's parameters.
    """

    def __init__(self, obs_dim, act_dim, lr=1e-4):
        self.model = Model(obs_dim, act_dim)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def _compute_loss(
        self,
        trajectories: tuple[
            TensorType["N", "T", "H", "W", "C"],
            TensorType["N", "T"],
            TensorType["N", "T"],
        ],
    ) -> torch.Tensor:
        """Return the policy-gradient loss for a batch of trajectories.

        Each trajectory contributes ``sum_t log pi(a_t | s_t) * A_t``, where
        ``A_t`` comes from ``compute_advantage``. The loss is the negative mean
        of these terms over trajectories, so minimizing it ascends the
        objective.

        Args:
            trajectories: ``(observations, actions, rewards)``. Each element is
                indexable per trajectory and then per step. Every observation
                is flattened before it is fed to the model.

        Returns:
            A scalar loss tensor connected to the model's parameters.
        """
        # Original shape note: n_observations (N, H, height, width) = (N, H, 80, 80);
        # n_actions, n_rewards: (N, H).
        # NOTE(review): the annotation says (N, T, H, W, C); confirm the layout callers use.
        n_observations, n_actions, n_rewards = trajectories
        # Everything is placed on the device of the first observation.
        device = n_observations[0][0].device
        terms = []
        for observations, actions, rewards in zip(n_observations, n_actions, n_rewards):
            advantages = compute_advantage(rewards).to(device=device)
            flat_obs = torch.stack([obs.flatten() for obs in observations]).to(device=device)
            # as_tensor accepts lists and tensors alike; torch.tensor warns on tensors.
            acts = torch.as_tensor(actions, device=device)
            terms.append((self.log_prob(flat_obs, acts) * advantages).sum())
        return -torch.stack(terms).mean()

    def _dist(self, obs):
        """Return the categorical action distribution for ``obs``."""
        return torch.distributions.Categorical(logits=self.model(obs))

    def train(self, trajectories):
        """Run one optimizer step on ``trajectories``; print and return the loss value."""
        loss = self._compute_loss(trajectories)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        loss_value = loss.item()
        print("Loss:", loss_value)
        return loss_value

    def log_prob(self, obs, act):
        """
        Return the log-probability of taking action ``act`` given observation ``obs``.
        """
        return self._dist(obs).log_prob(act)

    def act(self, obs):
        """Sample one action for ``obs`` and return it as a Python number.

        ``obs`` must yield a single action, e.g. one flattened observation,
        because ``.item()`` needs a one-element tensor. Sampling runs without
        autograd tracking, since the result is never differentiated.
        """
        with torch.no_grad():
            return self._dist(obs).sample().item()

    def save(self, path):
        """Write the model's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path, map_location=None):
        """Load a ``state_dict`` written by ``save`` into the model.

        ``map_location`` is forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to
        load a GPU-saved checkpoint on a CPU-only machine.
        """
        # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
        self.model.load_state_dict(torch.load(path, map_location=map_location))