
Commit 4d63716

Merge pull request #80 from cpnota/release/0.2.2
Release/0.2.2
2 parents: 341e525 + dfef2a5


54 files changed: +1342 / -487 lines

all/agents/__init__.py

Lines changed: 9 additions & 5 deletions

@@ -1,17 +1,21 @@
-from .abstract import Agent
+from ._agent import Agent
 from .a2c import A2C
-from .actor_critic import ActorCritic
 from .ddpg import DDPG
 from .dqn import DQN
-from .sarsa import Sarsa
+from .ppo import PPO
+from .vac import VAC
 from .vpg import VPG
+from .vqn import VQN
+from .vsarsa import VSarsa

 __all__ = [
     "Agent",
     "A2C",
-    "ActorCritic",
     "DDPG",
     "DQN",
-    "Sarsa",
+    "PPO",
+    "VAC",
     "VPG",
+    "VQN",
+    "VSarsa",
 ]
File renamed without changes.
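
For reference, the reorganized package exposes the new agents directly from all.agents. A minimal sketch of the updated import surface, taken from the __all__ list above (constructing each agent still requires the approximation objects shown in the individual files below):

import torch  # the library is built on PyTorch

# Names exported by all.agents after this commit, per __all__ above.
from all.agents import A2C, DDPG, DQN, PPO, VAC, VPG, VQN, VSarsa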

all/agents/a2c.py

Lines changed: 9 additions & 12 deletions

@@ -1,7 +1,7 @@
 import torch
 from all.environments import State
-from all.memory import NStepBatchBuffer
-from .abstract import Agent
+from all.memory import NStepAdvantageBuffer
+from ._agent import Agent


 class A2C(Agent):
@@ -36,20 +36,17 @@ def act(self, states, rewards):
     def _train(self):
         if len(self._buffer) >= self._batch_size:
             states = State.from_list(self._features)
-            _, _, returns, next_states, rollout_lengths = self._buffer.sample(self._batch_size)
-            td_errors = (
-                returns
-                + (self.discount_factor ** rollout_lengths)
-                * self.v.eval(self.features.eval(next_states))
-                - self.v(states)
-            )
-            self.v.reinforce(td_errors)
-            self.policy.reinforce(td_errors)
+            _, _, advantages = self._buffer.sample(self._batch_size)
+            self.v(states)
+            self.v.reinforce(advantages)
+            self.policy.reinforce(advantages)
             self.features.reinforce()
             self._features = []

     def _make_buffer(self):
-        return NStepBatchBuffer(
+        return NStepAdvantageBuffer(
+            self.v,
+            self.features,
             self.n_steps,
             self.n_envs,
             discount_factor=self.discount_factor
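
For readers unfamiliar with the change: the deleted code computed the TD errors by hand as returns + discount_factor**k * V(s_{t+k}) - V(s_t), and the new NStepAdvantageBuffer presumably folds an equivalent computation into sample(). A minimal standalone sketch of an n-step advantage estimate in plain PyTorch; the rewards and value estimates are made-up placeholders, not the library's API:

import torch

# Hypothetical rollout for a single environment: n_steps rewards, the value
# estimate at the starting state, and the value estimate at the bootstrap state.
n_steps = 4
discount_factor = 0.99
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
v_start = torch.tensor(2.5)       # V(s_t), placeholder
v_bootstrap = torch.tensor(3.0)   # V(s_{t+n}), placeholder

# n-step return: discounted sum of rewards plus the discounted bootstrap value.
discounts = discount_factor ** torch.arange(n_steps, dtype=torch.float32)
n_step_return = (discounts * rewards).sum() + discount_factor ** n_steps * v_bootstrap

# The advantage is the n-step return minus the current value estimate,
# matching the quantity the deleted code called td_errors.
advantage = n_step_return - v_start
print(advantage)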

all/agents/actor_critic.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

all/agents/ddpg.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import torch
-from .abstract import Agent
+from ._agent import Agent

 class DDPG(Agent):
     def __init__(self,

all/agents/dqn.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import torch
-from .abstract import Agent
+from ._agent import Agent


 class DQN(Agent):

all/agents/ppo.py (new file)

Lines changed: 92 additions & 0 deletions

import torch
from all.memory import GeneralizedAdvantageBuffer
from ._agent import Agent


class PPO(Agent):
    def __init__(
            self,
            features,
            v,
            policy,
            epsilon=0.2,
            epochs=4,
            minibatches=4,
            n_envs=None,
            n_steps=4,
            discount_factor=0.99,
            lam=0.95
    ):
        if n_envs is None:
            raise RuntimeError("Must specify n_envs.")
        self.features = features
        self.v = v
        self.policy = policy
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.discount_factor = discount_factor
        self.lam = lam
        self._epsilon = epsilon
        self._epochs = epochs
        self._batch_size = n_envs * n_steps
        self._minibatches = minibatches
        self._buffer = self._make_buffer()
        self._features = []

    def act(self, states, rewards):
        self._train()
        actions = self.policy.eval(self.features.eval(states))
        self._buffer.store(states, actions, rewards)
        return actions

    def _train(self):
        if len(self._buffer) >= self._batch_size:
            states, actions, advantages = self._buffer.sample(self._batch_size)
            with torch.no_grad():
                features = self.features.eval(states)
                pi_0 = self.policy.eval(features, actions)
                targets = self.v.eval(features) + advantages
            for _ in range(self._epochs):
                self._train_epoch(states, actions, pi_0, advantages, targets)

    def _train_epoch(self, states, actions, pi_0, advantages, targets):
        minibatch_size = int(self._batch_size / self._minibatches)
        indexes = torch.randperm(self._batch_size)
        for n in range(self._minibatches):
            first = n * minibatch_size
            last = first + minibatch_size
            i = indexes[first:last]
            self._train_minibatch(states[i], actions[i], pi_0[i], advantages[i], targets[i])

    def _train_minibatch(self, states, actions, pi_0, advantages, targets):
        features = self.features(states)
        self.policy(features, actions)
        self.policy.reinforce(self._compute_policy_loss(pi_0, advantages))
        self.v.reinforce(targets - self.v(features))
        self.features.reinforce()

    def _compute_targets(self, returns, next_states, lengths):
        return (
            returns +
            (self.discount_factor ** lengths)
            * self.v.eval(self.features.eval(next_states))
        )

    def _compute_policy_loss(self, pi_0, advantages):
        def _policy_loss(pi_i):
            ratios = torch.exp(pi_i - pi_0)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1.0 - self._epsilon, 1.0 + self._epsilon) * advantages
            return -torch.min(surr1, surr2).mean()
        return _policy_loss

    def _make_buffer(self):
        return GeneralizedAdvantageBuffer(
            self.v,
            self.features,
            self.n_steps,
            self.n_envs,
            discount_factor=self.discount_factor,
            lam=self.lam
        )
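
Two brief standalone sketches may help map this file onto the PPO paper. First, the GeneralizedAdvantageBuffer presumably computes GAE(lambda); the standard recursion, written in plain PyTorch on made-up reward and value tensors rather than the library's buffer API:

import torch

# Made-up single-environment rollout: rewards and value estimates V(s_0..s_n).
gamma, lam = 0.99, 0.95
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
values = torch.tensor([2.0, 1.8, 1.5, 1.2, 1.0])  # includes bootstrap value V(s_n)

# GAE(lambda) recursion: A_t = delta_t + gamma * lam * A_{t+1},
# where delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
advantages = torch.zeros(len(rewards))
gae = 0.0
for t in reversed(range(len(rewards))):
    delta = rewards[t] + gamma * values[t + 1] - values[t]
    gae = delta + gamma * lam * gae
    advantages[t] = gae
print(advantages)

Second, _compute_policy_loss above is the clipped surrogate objective, min(r * A, clip(r, 1 - eps, 1 + eps) * A), negated because the approximation objects minimize a loss. The same computation on dummy log-probabilities and advantages:

import torch

epsilon = 0.2
advantages = torch.tensor([0.5, -0.3, 1.2])
log_pi_0 = torch.tensor([-1.0, -0.7, -2.0])   # log-probs under the old policy
log_pi_i = torch.tensor([-0.9, -0.8, -1.5])   # log-probs under the current policy

ratios = torch.exp(log_pi_i - log_pi_0)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
loss = -torch.min(surr1, surr2).mean()        # negated: lower loss = better surrogate
print(loss)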

all/agents/vac.py (new file)

Lines changed: 23 additions & 0 deletions

from ._agent import Agent

class VAC(Agent):
    '''Vanilla Actor-Critic'''
    def __init__(self, features, v, policy, gamma=1):
        self.features = features
        self.v = v
        self.policy = policy
        self.gamma = gamma
        self._previous_features = None

    def act(self, state, reward):
        if self._previous_features:
            td_error = (
                reward
                + self.gamma * self.v.eval(self.features.eval(state))
                - self.v(self._previous_features)
            )
            self.v.reinforce(td_error)
            self.policy.reinforce(td_error)
            self.features.reinforce()
        self._previous_features = self.features(state)
        return self.policy(self._previous_features)
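
The TD error above is the usual one-step target minus the current estimate, r + gamma * V(s') - V(s). A tiny numeric sketch with placeholder values, not the library's approximation objects:

import torch

gamma = 0.99
reward = torch.tensor(1.0)
v_next = torch.tensor(2.0)      # V(s'), placeholder value estimate
v_current = torch.tensor(2.5)   # V(s), placeholder value estimate

# One-step TD error: drives both the critic update and, as the advantage
# estimate, the actor update in the vanilla actor-critic above.
td_error = reward + gamma * v_next - v_current
print(td_error)  # tensor(0.4800)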

all/agents/vpg.py

Lines changed: 10 additions & 4 deletions

@@ -1,21 +1,25 @@
 import torch
 from all.environments import State
-from .abstract import Agent
+from ._agent import Agent

 class VPG(Agent):
+    '''Vanilla Policy Gradient'''
     def __init__(
             self,
             features,
             v,
             policy,
             gamma=0.99,
-            n_episodes=1
+            # run complete episodes until we have
+            # seen at least min_batch_size states
+            min_batch_size=1
     ):
         self.features = features
         self.v = v
         self.policy = policy
         self.gamma = gamma
-        self.n_episodes = n_episodes
+        self.min_batch_size = min_batch_size
+        self._current_batch_size = 0
         self._trajectories = []
         self._features = []
         self._rewards = []
@@ -43,10 +47,11 @@ def _terminal(self, reward):
         features = torch.cat(self._features)
         rewards = torch.tensor(self._rewards, device=features.device)
         self._trajectories.append((features, rewards))
+        self._current_batch_size += len(features)
         self._features = []
         self._rewards = []

-        if len(self._trajectories) >= self.n_episodes:
+        if self._current_batch_size >= self.min_batch_size:
             self._train()

     def _train(self):
@@ -59,6 +64,7 @@ def _train(self):
         self.policy.reinforce(advantages)
         self.features.reinforce()
         self._trajectories = []
+        self._current_batch_size = 0

     def _compute_advantages(self, features, rewards):
         returns = self._compute_discounted_returns(rewards)
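
As context for _compute_advantages, here is a standalone sketch of discounted return computation over one trajectory, written against made-up rewards rather than the library's helpers; _compute_discounted_returns presumably produces something equivalent:

import torch

def discounted_returns(rewards, gamma):
    # G_t = r_t + gamma * G_{t+1}, computed backwards over the trajectory.
    returns = torch.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
print(discounted_returns(rewards, gamma=0.99))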

all/agents/vqn.py (new file)

Lines changed: 26 additions & 0 deletions

import torch
from ._agent import Agent


class VQN(Agent):
    '''Vanilla Q-Network'''
    def __init__(self, q, policy, gamma=1):
        self.q = q
        self.policy = policy
        self.gamma = gamma
        self.env = None
        self.previous_state = None
        self.previous_action = None

    def act(self, state, reward):
        action = self.policy(state)
        if self.previous_state:
            td_error = (
                reward
                + self.gamma * torch.max(self.q.eval(state), dim=1)[0]
                - self.q(self.previous_state, self.previous_action)
            )
            self.q.reinforce(td_error)
        self.previous_state = state
        self.previous_action = action
        return action
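
The update above is standard Q-learning: the target bootstraps from the greedy action via max over Q(s', a). A small sketch with placeholder Q-value tensors (a batch of one environment with three actions), not the library's approximation objects:

import torch

gamma = 0.99
reward = torch.tensor([1.0])

# Placeholder Q-value estimates for the next state (batch_size=1, n_actions=3)
# and for the previously chosen action in the previous state.
q_next = torch.tensor([[0.2, 1.5, 0.7]])
q_previous = torch.tensor([1.0])

# Q-learning TD error: r + gamma * max_a Q(s', a) - Q(s, a)
td_error = reward + gamma * torch.max(q_next, dim=1)[0] - q_previous
print(td_error)  # tensor([1.4850])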
