
Commit 9368f98

Merge pull request #66 from cpnota/release/0.2.0
Release/0.2.0
2 parents 021f0a0 + 1a4c477 commit 9368f98


46 files changed: +1027 −692 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@ all.egg-info
 local
 legacy
 /runs
+/out

all/agents/a2c.py

Lines changed: 21 additions & 16 deletions
@@ -1,4 +1,5 @@
 import torch
+from all.environments import State
 from all.memory import NStepBatchBuffer
 from .abstract import Agent

@@ -23,30 +24,34 @@ def __init__(
         self.discount_factor = discount_factor
         self._batch_size = n_envs * n_steps
         self._buffer = self._make_buffer()
+        self._features = []

-    def act(self, states, rewards, info=None):
-        # store transition and train BEFORE choosing action
-        # Do not need to know actions, so pass in empy array
+    def act(self, states, rewards):
         self._buffer.store(states, torch.zeros(self.n_envs), rewards)
-        while len(self._buffer) >= self._batch_size:
-            self._train()
-        return self.policy(self.features(states))
+        self._train()
+        features = self.features(states)
+        self._features.append(features)
+        return self.policy(features)

     def _train(self):
-        states, _, next_states, returns, rollout_lengths = self._buffer.sample(self._batch_size)
-        td_errors = (
-            returns
-            + (self.discount_factor ** rollout_lengths)
-            * self.v.eval(self.features.eval(next_states))
-            - self.v(self.features(states))
-        )
-        self.v.reinforce(td_errors, retain_graph=True)
-        self.policy.reinforce(td_errors)
-        self.features.reinforce()
+        if len(self._buffer) >= self._batch_size:
+            states = State.from_list(self._features)
+            _, _, returns, next_states, rollout_lengths = self._buffer.sample(self._batch_size)
+            td_errors = (
+                returns
+                + (self.discount_factor ** rollout_lengths)
+                * self.v.eval(self.features.eval(next_states))
+                - self.v(states)
+            )
+            self.v.reinforce(td_errors)
+            self.policy.reinforce(td_errors)
+            self.features.reinforce()
+            self._features = []

     def _make_buffer(self):
         return NStepBatchBuffer(
             self.n_steps,
             self.n_envs,
             discount_factor=self.discount_factor
         )
+
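
For reference, the n-step TD error computed in the rewritten _train can be reproduced in isolation. The sketch below uses plain tensors as stand-ins for the values returned by NStepBatchBuffer.sample and by the v approximation; the numbers are illustrative, not library output.

# Standalone sketch of the n-step TD error computed in _train above:
#     td = R_t^(n) + gamma^n * V(s_{t+n}) - V(s_t)
import torch

discount_factor = 0.99
returns = torch.tensor([1.0, 0.5, 0.0])          # n-step discounted returns from the buffer
rollout_lengths = torch.tensor([4.0, 4.0, 2.0])  # steps actually rolled out per sample
next_values = torch.tensor([0.7, 0.2, 0.0])      # stand-in for v.eval(features.eval(next_states))
values = torch.tensor([0.9, 0.4, 0.1])           # stand-in for v(states)

td_errors = (
    returns
    + (discount_factor ** rollout_lengths) * next_values
    - values
)
print(td_errors)  # advantage estimates passed to v.reinforce and policy.reinforce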

all/agents/abstract.py

Lines changed: 1 addition & 39 deletions
@@ -11,27 +11,8 @@ class Agent(ABC):
     An Agent implementation should encapsulate some particular reinforcement learning algorihthm.
     """

-    def initial(self, state, info=None):
-        """
-        Choose an action in the initial state of a new episode.
-
-        Reinforcement learning problems are often broken down into sequences called "episodes".
-        An episode is a self-contained sequence of states, actions, and rewards.
-        A "trial" consists of multiple episodes, and represents the lifetime of an agent.
-        This method is called at the beginning of an episode.
-
-        Parameters
-        ----------
-        state: The initial state of the new episode
-        info (optional): The info object from the environment
-
-        Returns
-        _______
-        action: The action to take in the initial state
-        """
-
     @abstractmethod
-    def act(self, state, reward, info=None):
+    def act(self, state, reward):
         """
         Select an action for the current timestep and update internal parameters.

@@ -53,22 +34,3 @@ def act(self, state, reward, info=None):
         _______
         action: The action to take at the current timestep
         """
-
-    def terminal(self, reward, info=None):
-        """
-        Accept the final reward of the episode and perform final updates.
-
-        After the final action is selected, it is still necessary to
-        consider the reward given on the final timestep. This method
-        provides a hook where the agent can examine this reward
-        and perform any necessary updates.
-
-        Parameters
-        ----------
-        reward: The reward from the previous timestep
-        info (optional): The info object from the environment
-
-        Returns
-        _______
-        None
-        """

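With initial and terminal removed, act(state, reward) is the only method an Agent must implement, and episode boundaries are read from the state itself (for example state.done, as the agents in this commit do). A hypothetical minimal implementation against the new interface is sketched below; RandomAgent and n_actions are illustrative names, not part of the library, and the import path simply mirrors the module shown above.

# Hypothetical agent written against the simplified interface (illustration only).
import random

from all.agents.abstract import Agent

class RandomAgent(Agent):
    def __init__(self, n_actions):
        self.n_actions = n_actions

    def act(self, state, reward):
        # reward is the result of the previous action; a learning agent would
        # update its parameters here before choosing the next action
        return random.randrange(self.n_actions)
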
all/agents/actor_critic.py

Lines changed: 7 additions & 12 deletions
@@ -7,19 +7,14 @@ def __init__(self, v, policy, gamma=1):
         self.gamma = gamma
         self.previous_state = None

-    def initial(self, state, info=None):
-        self.previous_state = state
-        return self.policy(state)
-
-    def act(self, state, reward, info=None):
-        if self.previous_state is not None:
-            td_error = reward + self.gamma * self.v.eval(state) - self.v(self.previous_state)
+    def act(self, state, reward):
+        if self.previous_state:
+            td_error = (
+                reward
+                + self.gamma * self.v.eval(state)
+                - self.v(self.previous_state)
+            )
             self.v.reinforce(td_error)
             self.policy.reinforce(td_error)
         self.previous_state = state
         return self.policy(state)
-
-    def terminal(self, reward, info=None):
-        td_error = reward - self.v(self.previous_state)
-        self.v.reinforce(td_error)
-        self.policy.reinforce(td_error)
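
The rewritten act folds the old initial/act/terminal logic into a single one-step TD(0) update, guarded by if self.previous_state so that no update happens before a previous state has been stored. A minimal numeric sketch of the quantity it computes, with plain floats standing in for the v approximations:

# Illustrative one-step TD(0) error, the quantity computed in act above.
gamma = 0.99
reward = 1.0
value_next = 0.5  # stand-in for v.eval(state): bootstrapped estimate, no gradient
value_prev = 0.8  # stand-in for v(previous_state)

td_error = reward + gamma * value_next - value_prev
print(td_error)   # reinforces both the critic v and the policy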

all/agents/dqn.py

Lines changed: 22 additions & 31 deletions
@@ -22,44 +22,35 @@ def __init__(self,
         self.minibatch_size = minibatch_size
         self.discount_factor = discount_factor
         # data
-        self.frames_seen = 0
         self.env = None
         self.state = None
         self.action = None
+        self.frames_seen = 0

-    def initial(self, state, info=None):
+    def act(self, state, reward):
+        self._store_transition(state, reward)
+        self._train()
         self.state = state
-        self.action = self.policy(self.state)
-        return self.action
-
-    def act(self, state, reward, info=None):
-        self.store_transition(state, reward)
-        if self.should_train():
-            self.train()
         self.action = self.policy(state)
         return self.action

-    def terminal(self, reward, info=None):
-        self.store_transition(None, reward)
-        if self.should_train():
-            self.train()
-
-    def store_transition(self, state, reward):
-        self.frames_seen += 1
-        self.replay_buffer.store(self.state, self.action, state, reward)
-        self.state = state
-
-    def should_train(self):
+    def _store_transition(self, state, reward):
+        if self.state and not self.state.done:
+            self.frames_seen += 1
+            self.replay_buffer.store(self.state, self.action, reward, state)
+
+    def _train(self):
+        if self._should_train():
+            (states, actions, rewards, next_states, weights) = self.replay_buffer.sample(
+                self.minibatch_size)
+            td_errors = (
+                rewards +
+                self.discount_factor * torch.max(self.q.eval(next_states), dim=1)[0] -
+                self.q(states, actions)
+            )
+            self.q.reinforce(weights * td_errors)
+            self.replay_buffer.update_priorities(td_errors)
+
+    def _should_train(self):
         return (self.frames_seen > self.replay_start_size and
                 self.frames_seen % self.update_frequency == 0)
-
-    def train(self):
-        (states, actions, next_states, rewards, weights) = self.replay_buffer.sample(
-            self.minibatch_size)
-        td_errors = (
-            rewards +
-            self.discount_factor * torch.max(self.q.eval(next_states), dim=1)[0] -
-            self.q(states, actions)
-        )
-        self.q.reinforce(weights * td_errors)
-        self.replay_buffer.update_priorities(td_errors)
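
Training now happens inside _train, gated by _should_train. The weighted Q-learning TD errors it computes can be sketched in isolation; the random tensors below only stand in for a replay-buffer sample and are not library output.

# Standalone sketch of the weighted Q-learning TD errors computed in _train above.
import torch

batch_size, n_actions = 32, 4
discount_factor = 0.99
rewards = torch.randn(batch_size)
q_next = torch.randn(batch_size, n_actions)  # stand-in for q.eval(next_states)
q_taken = torch.randn(batch_size)            # stand-in for q(states, actions)
weights = torch.ones(batch_size)             # importance-sampling weights from the buffer

td_errors = (
    rewards
    + discount_factor * torch.max(q_next, dim=1)[0]  # greedy bootstrap over next actions
    - q_taken
)
weighted_errors = weights * td_errors  # passed to q.reinforce; td_errors update priorities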

all/agents/sarsa.py

Lines changed: 14 additions & 24 deletions
@@ -7,28 +7,18 @@ def __init__(self, q, policy, gamma=1):
         self.policy = policy
         self.gamma = gamma
         self.env = None
-        self.state = None
-        self.action = None
-        self.next_state = None
-        self.next_action = None
+        self.previous_state = None
+        self.previous_action = None

-    def initial(self, state, info=None):
-        self.state = state
-        self.action = self.policy(self.state)
-        return self.action
-
-    def act(self, next_state, reward, info=None):
-        next_action = self.policy(next_state)
-        td_error = (
-            reward
-            + self.gamma * self.q.eval(next_state, next_action)
-            - self.q(self.state, self.action)
-        )
-        self.q.reinforce(td_error)
-        self.state = next_state
-        self.action = next_action
-        return self.action
-
-    def terminal(self, reward, info=None):
-        td_error = reward - self.q(self.state, self.action)
-        self.q.reinforce(td_error)
+    def act(self, state, reward):
+        action = self.policy(state)
+        if self.previous_state:
+            td_error = (
+                reward
+                + self.gamma * self.q.eval(state, action)
+                - self.q(self.previous_state, self.previous_action)
+            )
+            self.q.reinforce(td_error)
+        self.previous_state = state
+        self.previous_action = action
+        return action
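
The new act chooses the next action first and then bootstraps from the value of that chosen action (on-policy), in contrast to the max over next-state action values in the DQN update above. A minimal numeric sketch with plain floats standing in for the q approximations:

# Illustrative SARSA TD error, the quantity computed in act above.
gamma = 0.99
reward = 1.0
q_next_chosen = 0.5  # stand-in for q.eval(state, action), using the action just selected
q_prev = 0.8         # stand-in for q(previous_state, previous_action)

td_error = reward + gamma * q_next_chosen - q_prev  # on-policy target
print(td_error)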

all/agents/vpg.py

Lines changed: 34 additions & 21 deletions
@@ -1,4 +1,5 @@
 import torch
+from all.environments import State
 from .abstract import Agent

 class VPG(Agent):

@@ -16,40 +17,52 @@ def __init__(
         self.gamma = gamma
         self.n_episodes = n_episodes
         self._trajectories = []
-        self._states = None
-        self._rewards = None
+        self._features = []
+        self._rewards = []
+
+    def act(self, state, reward):
+        if not self._features:
+            return self._initial(state)
+        if not state.done:
+            return self._act(state, reward)
+        return self._terminal(reward)

-    def initial(self, state, info=None):
+    def _initial(self, state):
         features = self.features(state)
-        self._states = [features]
-        self._rewards = []
+        self._features = [features.features]
         return self.policy(features)

-    def act(self, state, reward, info=None):
+    def _act(self, state, reward):
         features = self.features(state)
-        self._states.append(features)
+        self._features.append(features.features)
         self._rewards.append(reward)
         return self.policy(features)

-    def terminal(self, reward, info=None):
+    def _terminal(self, reward):
         self._rewards.append(reward)
-        states = torch.cat(self._states)
-        rewards = torch.tensor(self._rewards, device=states.device)
-        self._trajectories.append((states, rewards))
+        features = torch.cat(self._features)
+        rewards = torch.tensor(self._rewards, device=features.device)
+        self._trajectories.append((features, rewards))
+        self._features = []
+        self._rewards = []
+
         if len(self._trajectories) >= self.n_episodes:
-            advantages = torch.cat([
-                self._compute_advantages(states, rewards)
-                for (states, rewards)
-                in self._trajectories
-            ])
-            self.v.reinforce(advantages, retain_graph=True)
-            self.policy.reinforce(advantages)
-            self.features.reinforce()
-            self._trajectories = []
+            self._train()
+
+    def _train(self):
+        advantages = torch.cat([
+            self._compute_advantages(features, rewards)
+            for (features, rewards)
+            in self._trajectories
+        ])
+        self.v.reinforce(advantages, retain_graph=True)
+        self.policy.reinforce(advantages)
+        self.features.reinforce()
+        self._trajectories = []

     def _compute_advantages(self, features, rewards):
         returns = self._compute_discounted_returns(rewards)
-        values = self.v(features)
+        values = self.v(State(features))
         return returns - values

     def _compute_discounted_returns(self, rewards):
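
The refactored agent accumulates raw feature tensors (features.features) during an episode, wraps them back into a State when querying the value head, and defers the policy-gradient step to _train. The advantages it concatenates are discounted returns minus values; the discounting loop below is a generic illustration and not the library's _compute_discounted_returns.

# Illustrative advantage computation for a single trajectory (generic sketch).
import torch

gamma = 0.99
rewards = torch.tensor([0.0, 0.0, 1.0])  # rewards collected over one episode
values = torch.tensor([0.3, 0.5, 0.9])   # stand-in for v(State(features))

returns = torch.zeros_like(rewards)
running = torch.tensor(0.0)
for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running  # discounted return-to-go
    returns[t] = running

advantages = returns - values  # concatenated across trajectories and passed to reinforce
print(advantages)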

all/approximation/feature_network.py

Lines changed: 33 additions & 6 deletions
@@ -1,27 +1,54 @@
 import torch
 from torch.nn import utils
-from all.layers import ListToList
+from all.environments import State
 from .features import Features

 class FeatureNetwork(Features):
     def __init__(self, model, optimizer, clip_grad=0):
-        self.model = ListToList(model)
+        self.model = model
         self.optimizer = optimizer
         self.clip_grad = clip_grad
+        self._cache = []
+        self._out = []

     def __call__(self, states):
-        return self.model(states)
+        features = self.model(states.features.float())
+        out = features.detach()
+        out.requires_grad = True
+        self._cache.append(features)
+        self._out.append(out)
+        return State(
+            out,
+            mask=states.mask,
+            info=states.info
+        )

     def eval(self, states):
         with torch.no_grad():
             training = self.model.training
-            result = self.model(states)
+            result = self.model(states.features.float())
             self.model.train(training)
-            return result
+            return State(
+                result,
+                mask=states.mask,
+                info=states.info
+            )

     def reinforce(self):
-        # loss comes from elsewhere
+        graphs, grads = self._decache()
+        graphs.backward(grads)
         if self.clip_grad != 0:
             utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
         self.optimizer.step()
         self.optimizer.zero_grad()
+
+    def _decache(self):
+        graphs = []
+        grads = []
+        for graph, out in zip(self._cache, self._out):
+            if out.grad is not None:
+                graphs.append(graph)
+                grads.append(out.grad)
+        self._cache = []
+        self._out = []
+        return torch.cat(graphs), torch.cat(grads)
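
The new __call__/reinforce pair implements a detach-and-cache pattern: downstream heads receive detached feature tensors, and whatever gradients accumulate on those tensors are later pushed back through the cached graph of the shared model in a single backward call. The standalone sketch below reproduces the same pattern with plain PyTorch, independent of the library's State and Features classes.

# Standalone sketch of the detach-and-cache gradient pattern used above.
import torch

model = torch.nn.Linear(4, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(3, 4)
features = model(x)       # still attached to the model's graph (the "_cache" entry)
out = features.detach()   # what a downstream head sees (the "_out" entry)
out.requires_grad = True

loss = out.sum()          # stand-in for a downstream head's loss
loss.backward()           # fills out.grad without touching the model's parameters

features.backward(out.grad)  # reinforce(): push the cached gradients through the model
optimizer.step()
optimizer.zero_grad()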
