Merge pull request #66 from cpnota/release/0.2.0
Release/0.2.0
cpnota authored Jun 7, 2019
2 parents 021f0a0 + 1a4c477 commit 9368f98
Showing 46 changed files with 1,027 additions and 692 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,3 +11,4 @@ all.egg-info
local
legacy
/runs
/out
37 changes: 21 additions & 16 deletions all/agents/a2c.py
@@ -1,4 +1,5 @@
import torch
from all.environments import State
from all.memory import NStepBatchBuffer
from .abstract import Agent

@@ -23,30 +24,34 @@ def __init__(
self.discount_factor = discount_factor
self._batch_size = n_envs * n_steps
self._buffer = self._make_buffer()
self._features = []

def act(self, states, rewards, info=None):
# store transition and train BEFORE choosing action
# Do not need to know actions, so pass in empty array
def act(self, states, rewards):
self._buffer.store(states, torch.zeros(self.n_envs), rewards)
while len(self._buffer) >= self._batch_size:
self._train()
return self.policy(self.features(states))
self._train()
features = self.features(states)
self._features.append(features)
return self.policy(features)

def _train(self):
states, _, next_states, returns, rollout_lengths = self._buffer.sample(self._batch_size)
td_errors = (
returns
+ (self.discount_factor ** rollout_lengths)
* self.v.eval(self.features.eval(next_states))
- self.v(self.features(states))
)
self.v.reinforce(td_errors, retain_graph=True)
self.policy.reinforce(td_errors)
self.features.reinforce()
if len(self._buffer) >= self._batch_size:
states = State.from_list(self._features)
_, _, returns, next_states, rollout_lengths = self._buffer.sample(self._batch_size)
td_errors = (
returns
+ (self.discount_factor ** rollout_lengths)
* self.v.eval(self.features.eval(next_states))
- self.v(states)
)
self.v.reinforce(td_errors)
self.policy.reinforce(td_errors)
self.features.reinforce()
self._features = []

def _make_buffer(self):
return NStepBatchBuffer(
self.n_steps,
self.n_envs,
discount_factor=self.discount_factor
)
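
The refactored act() above no longer loops on the buffer size; _train() itself checks whether n_envs * n_steps transitions have accumulated and otherwise returns without updating. A standalone sketch of that gating with hypothetical values for n_envs and n_steps (the assumption that the buffer grows by n_envs transitions per call to act() is illustrative, not taken from this diff):

# Illustrative sketch only, not library code.
n_envs, n_steps = 4, 8
batch_size = n_envs * n_steps                 # mirrors self._batch_size above
buffer_len = 0

for t in range(1, 25):
    buffer_len += n_envs                      # store() adds one transition per env
    if buffer_len >= batch_size:              # the check now inside _train()
        print(f"step {t}: train on {batch_size} transitions")
        buffer_len = 0                        # sample() consumes the buffer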

40 changes: 1 addition & 39 deletions all/agents/abstract.py
@@ -11,27 +11,8 @@ class Agent(ABC):
An Agent implementation should encapsulate some particular reinforcement learning algorithm.
"""

def initial(self, state, info=None):
"""
Choose an action in the initial state of a new episode.
Reinforcement learning problems are often broken down into sequences called "episodes".
An episode is a self-contained sequence of states, actions, and rewards.
A "trial" consists of multiple episodes, and represents the lifetime of an agent.
This method is called at the beginning of an episode.
Parameters
----------
state: The initial state of the new episode
info (optional): The info object from the environment
Returns
_______
action: The action to take in the initial state
"""

@abstractmethod
def act(self, state, reward, info=None):
def act(self, state, reward):
"""
Select an action for the current timestep and update internal parameters.
@@ -53,22 +34,3 @@ def act(self, state, reward, info=None):
_______
action: The action to take at the current timestep
"""

def terminal(self, reward, info=None):
"""
Accept the final reward of the episode and perform final updates.
After the final action is selected, it is still necessary to
consider the reward given on the final timestep. This method
provides a hook where the agent can examine this reward
and perform any necessary updates.
Parameters
----------
reward: The reward from the previous timestep
info (optional): The info object from the environment
Returns
_______
None
"""
19 changes: 7 additions & 12 deletions all/agents/actor_critic.py
@@ -7,19 +7,14 @@ def __init__(self, v, policy, gamma=1):
self.gamma = gamma
self.previous_state = None

def initial(self, state, info=None):
self.previous_state = state
return self.policy(state)

def act(self, state, reward, info=None):
if self.previous_state is not None:
td_error = reward + self.gamma * self.v.eval(state) - self.v(self.previous_state)
def act(self, state, reward):
if self.previous_state:
td_error = (
reward
+ self.gamma * self.v.eval(state)
- self.v(self.previous_state)
)
self.v.reinforce(td_error)
self.policy.reinforce(td_error)
self.previous_state = state
return self.policy(state)

def terminal(self, reward, info=None):
td_error = reward - self.v(self.previous_state)
self.v.reinforce(td_error)
self.policy.reinforce(td_error)
53 changes: 22 additions & 31 deletions all/agents/dqn.py
@@ -22,44 +22,35 @@ def __init__(self,
self.minibatch_size = minibatch_size
self.discount_factor = discount_factor
# data
self.frames_seen = 0
self.env = None
self.state = None
self.action = None
self.frames_seen = 0

def initial(self, state, info=None):
def act(self, state, reward):
self._store_transition(state, reward)
self._train()
self.state = state
self.action = self.policy(self.state)
return self.action

def act(self, state, reward, info=None):
self.store_transition(state, reward)
if self.should_train():
self.train()
self.action = self.policy(state)
return self.action

def terminal(self, reward, info=None):
self.store_transition(None, reward)
if self.should_train():
self.train()

def store_transition(self, state, reward):
self.frames_seen += 1
self.replay_buffer.store(self.state, self.action, state, reward)
self.state = state

def should_train(self):
def _store_transition(self, state, reward):
if self.state and not self.state.done:
self.frames_seen += 1
self.replay_buffer.store(self.state, self.action, reward, state)

def _train(self):
if self._should_train():
(states, actions, rewards, next_states, weights) = self.replay_buffer.sample(
self.minibatch_size)
td_errors = (
rewards +
self.discount_factor * torch.max(self.q.eval(next_states), dim=1)[0] -
self.q(states, actions)
)
self.q.reinforce(weights * td_errors)
self.replay_buffer.update_priorities(td_errors)

def _should_train(self):
return (self.frames_seen > self.replay_start_size and
self.frames_seen % self.update_frequency == 0)

def train(self):
(states, actions, next_states, rewards, weights) = self.replay_buffer.sample(
self.minibatch_size)
td_errors = (
rewards +
self.discount_factor * torch.max(self.q.eval(next_states), dim=1)[0] -
self.q(states, actions)
)
self.q.reinforce(weights * td_errors)
self.replay_buffer.update_priorities(td_errors)
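
The private _should_train() gates updates on both a warm-up period (replay_start_size) and a modulus on the frame counter (update_frequency). A quick check with made-up values shows which frames trigger an update:

# Hypothetical values, for illustration only.
replay_start_size = 5
update_frequency = 3

def should_train(frames_seen):
    # same condition as _should_train() above
    return frames_seen > replay_start_size and frames_seen % update_frequency == 0

print([f for f in range(1, 16) if should_train(f)])  # [6, 9, 12, 15]
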
38 changes: 14 additions & 24 deletions all/agents/sarsa.py
@@ -7,28 +7,18 @@ def __init__(self, q, policy, gamma=1):
self.policy = policy
self.gamma = gamma
self.env = None
self.state = None
self.action = None
self.next_state = None
self.next_action = None
self.previous_state = None
self.previous_action = None

def initial(self, state, info=None):
self.state = state
self.action = self.policy(self.state)
return self.action

def act(self, next_state, reward, info=None):
next_action = self.policy(next_state)
td_error = (
reward
+ self.gamma * self.q.eval(next_state, next_action)
- self.q(self.state, self.action)
)
self.q.reinforce(td_error)
self.state = next_state
self.action = next_action
return self.action

def terminal(self, reward, info=None):
td_error = reward - self.q(self.state, self.action)
self.q.reinforce(td_error)
def act(self, state, reward):
action = self.policy(state)
if self.previous_state:
td_error = (
reward
+ self.gamma * self.q.eval(state, action)
- self.q(self.previous_state, self.previous_action)
)
self.q.reinforce(td_error)
self.previous_state = state
self.previous_action = action
return action
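
The rewritten SARSA selects the next action first, then reinforces the previous (state, action) pair with the usual SARSA TD error. With made-up numbers the arithmetic looks like this:

# Worked example of the TD error above (hypothetical values).
gamma = 0.99
reward = 1.0
q_next = 2.0   # stands in for self.q.eval(state, action)
q_prev = 1.5   # stands in for self.q(self.previous_state, self.previous_action)
td_error = reward + gamma * q_next - q_prev
print(td_error)  # approximately 1.48
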
55 changes: 34 additions & 21 deletions all/agents/vpg.py
@@ -1,4 +1,5 @@
import torch
from all.environments import State
from .abstract import Agent

class VPG(Agent):
@@ -16,40 +17,52 @@ def __init__(
self.gamma = gamma
self.n_episodes = n_episodes
self._trajectories = []
self._states = None
self._rewards = None
self._features = []
self._rewards = []

def act(self, state, reward):
if not self._features:
return self._initial(state)
if not state.done:
return self._act(state, reward)
return self._terminal(reward)

def initial(self, state, info=None):
def _initial(self, state):
features = self.features(state)
self._states = [features]
self._rewards = []
self._features = [features.features]
return self.policy(features)

def act(self, state, reward, info=None):
def _act(self, state, reward):
features = self.features(state)
self._states.append(features)
self._features.append(features.features)
self._rewards.append(reward)
return self.policy(features)

def terminal(self, reward, info=None):
def _terminal(self, reward):
self._rewards.append(reward)
states = torch.cat(self._states)
rewards = torch.tensor(self._rewards, device=states.device)
self._trajectories.append((states, rewards))
features = torch.cat(self._features)
rewards = torch.tensor(self._rewards, device=features.device)
self._trajectories.append((features, rewards))
self._features = []
self._rewards = []

if len(self._trajectories) >= self.n_episodes:
advantages = torch.cat([
self._compute_advantages(states, rewards)
for (states, rewards)
in self._trajectories
])
self.v.reinforce(advantages, retain_graph=True)
self.policy.reinforce(advantages)
self.features.reinforce()
self._trajectories = []
self._train()

def _train(self):
advantages = torch.cat([
self._compute_advantages(features, rewards)
for (features, rewards)
in self._trajectories
])
self.v.reinforce(advantages, retain_graph=True)
self.policy.reinforce(advantages)
self.features.reinforce()
self._trajectories = []

def _compute_advantages(self, features, rewards):
returns = self._compute_discounted_returns(rewards)
values = self.v(features)
values = self.v(State(features))
return returns - values

def _compute_discounted_returns(self, rewards):
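
_compute_advantages subtracts value estimates from discounted returns; the body of _compute_discounted_returns is collapsed in this view. The following is only a generic sketch of a discounted-return computation, not the code from this file:

import torch

def discounted_returns(rewards, gamma):
    # backward recursion: G_t = r_t + gamma * G_{t+1}
    returns = torch.zeros_like(rewards)
    running = 0.
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns

print(discounted_returns(torch.tensor([1., 1., 1.]), 0.5))  # tensor([1.7500, 1.5000, 1.0000])
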
39 changes: 33 additions & 6 deletions all/approximation/feature_network.py
@@ -1,27 +1,54 @@
import torch
from torch.nn import utils
from all.layers import ListToList
from all.environments import State
from .features import Features

class FeatureNetwork(Features):
def __init__(self, model, optimizer, clip_grad=0):
self.model = ListToList(model)
self.model = model
self.optimizer = optimizer
self.clip_grad = clip_grad
self._cache = []
self._out = []

def __call__(self, states):
return self.model(states)
features = self.model(states.features.float())
out = features.detach()
out.requires_grad = True
self._cache.append(features)
self._out.append(out)
return State(
out,
mask=states.mask,
info=states.info
)

def eval(self, states):
with torch.no_grad():
training = self.model.training
result = self.model(states)
result = self.model(states.features.float())
self.model.train(training)
return result
return State(
result,
mask=states.mask,
info=states.info
)

def reinforce(self):
# loss comes from elsewhere
graphs, grads = self._decache()
graphs.backward(grads)
if self.clip_grad != 0:
utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
self.optimizer.step()
self.optimizer.zero_grad()

def _decache(self):
graphs = []
grads = []
for graph, out in zip(self._cache, self._out):
if out.grad is not None:
graphs.append(graph)
grads.append(out.grad)
self._cache = []
self._out = []
return torch.cat(graphs), torch.cat(grads)
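
The new FeatureNetwork detaches its output, hands the detached tensor (with requires_grad=True) to downstream approximators, and caches both tensors so that reinforce() can later push the accumulated gradients back through the shared model. A usage sketch, assuming State can be constructed directly from a feature tensor with default mask and info; the linear model, SGD optimizer, and summed loss are placeholders for the real value/policy heads:

import torch
from torch import nn, optim
from all.environments import State
from all.approximation.feature_network import FeatureNetwork

model = nn.Linear(4, 8)                              # placeholder feature model
features_net = FeatureNetwork(model, optim.SGD(model.parameters(), lr=0.01))

states = State(torch.randn(2, 4))                    # assumes default mask/info
feature_states = features_net(states)                # detached output, requires_grad=True
loss = feature_states.features.sum()                 # stand-in for the v/policy losses
loss.backward()                                      # gradients accumulate on the cached tensor
features_net.reinforce()                             # backprop through model, optimizer step
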
(Remaining changed files not shown.)
