Merge pull request #76 from cpnota/release/0.2.1
Release/0.2.1
cpnota authored Jul 12, 2019
2 parents 9368f98 + e0bfd25 commit 341e525
Showing 49 changed files with 1,000 additions and 360 deletions.
1 change: 1 addition & 0 deletions all/__init__.py
@@ -0,0 +1 @@
import all.nn
2 changes: 2 additions & 0 deletions all/agents/__init__.py
@@ -1,6 +1,7 @@
from .abstract import Agent
from .a2c import A2C
from .actor_critic import ActorCritic
from .ddpg import DDPG
from .dqn import DQN
from .sarsa import Sarsa
from .vpg import VPG
@@ -9,6 +10,7 @@
"Agent",
"A2C",
"ActorCritic",
"DDPG",
"DQN",
"Sarsa",
"VPG",
63 changes: 63 additions & 0 deletions all/agents/ddpg.py
@@ -0,0 +1,63 @@
import torch
from .abstract import Agent

class DDPG(Agent):
def __init__(self,
q,
policy,
replay_buffer,
discount_factor=0.99,
minibatch_size=32,
replay_start_size=5000,
update_frequency=1
):
# objects
self.q = q
self.policy = policy
self.replay_buffer = replay_buffer
# hyperparameters
self.replay_start_size = replay_start_size
self.update_frequency = update_frequency
self.minibatch_size = minibatch_size
self.discount_factor = discount_factor
# data
self.env = None
self.state = None
self.action = None
self.frames_seen = 0

def act(self, state, reward):
self._store_transition(state, reward)
self._train()
self.state = state
self.action = self.policy(state)
return self.action

def _store_transition(self, state, reward):
if self.state and not self.state.done:
self.frames_seen += 1
self.replay_buffer.store(self.state, self.action, reward, state)

def _train(self):
if self._should_train():
(states, actions, rewards, next_states, weights) = self.replay_buffer.sample(
self.minibatch_size)

# train q function
td_errors = (
rewards +
self.discount_factor * self.q.eval(next_states, self.policy.eval(next_states)) -
self.q(states, torch.cat(actions))
)
self.q.reinforce(weights * td_errors)
self.replay_buffer.update_priorities(td_errors)

# train policy
loss = -self.q(states, self.policy.greedy(states), detach=False).mean()
loss.backward()
self.policy.step()
self.q.zero_grad()

def _should_train(self):
return (self.frames_seen > self.replay_start_size and
self.frames_seen % self.update_frequency == 0)
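
As a stand-alone illustration of the two updates in _train() above, the following sketch repeats the same arithmetic with random tensors standing in for the network outputs; the tensor shapes are assumptions, the formulas come directly from the diff.

import torch

# Random tensors stand in for the network outputs used in DDPG._train().
discount_factor = 0.99
rewards = torch.randn(32)
q_next = torch.randn(32)     # plays the role of q.eval(next_states, policy.eval(next_states))
q_current = torch.randn(32)  # plays the role of q(states, torch.cat(actions))

# Critic error signal: one-step TD error r + gamma * Q'(s', mu'(s')) - Q(s, a),
# which is what the critic's reinforce() call receives.
td_errors = rewards + discount_factor * q_next - q_current

# Actor objective: maximize Q(s, mu(s)) by minimizing its negation.
q_greedy = torch.randn(32, requires_grad=True)  # plays the role of q(states, policy.greedy(states), detach=False)
policy_loss = -q_greedy.mean()
policy_loss.backward()
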
17 changes: 4 additions & 13 deletions all/approximation/__init__.py
@@ -1,15 +1,6 @@
from .q_function import QFunction
from .approximation import Approximation
from .q_continuous import QContinuous
from .q_network import QNetwork
from .v_function import ValueFunction
from .v_network import ValueNetwork
from .features import Features
from .v_network import VNetwork
from .feature_network import FeatureNetwork

__all__ = [
"QFunction",
"QNetwork",
"ValueFunction",
"ValueNetwork",
"Features",
"FeatureNetwork"
]
from .target import TargetNetwork, FixedTarget, PolyakTarget, TrivialTarget
82 changes: 82 additions & 0 deletions all/approximation/approximation.py
@@ -0,0 +1,82 @@
import torch
from torch.nn import utils
from torch.nn.functional import mse_loss
from all.experiments import DummyWriter
from .target import FixedTarget, TrivialTarget

class Approximation():
def __init__(
self,
model,
optimizer,
clip_grad=0,
loss_scaling=1,
loss=mse_loss,
name='approximation',
target=None,
writer=DummyWriter(),
):
self.model = model
self.device = next(model.parameters()).device
self._target = target or TrivialTarget()
self._target.init(model)
self._updates = 0
self._optimizer = optimizer
self._loss = loss
self._loss_scaling = loss_scaling
self._cache = []
self._clip_grad = clip_grad
self._writer = writer
self._name = name

def __call__(self, *inputs, detach=True):
result = self.model(*inputs)
if detach:
self._enqueue(result)
return result.detach()
return result

def eval(self, *inputs):
return self._target(*inputs)

def reinforce(self, errors, retain_graph=False):
batch_size = len(errors)
cache = self._dequeue(batch_size)
if cache.requires_grad:
loss = self._loss(cache, errors) * self._loss_scaling
self._writer.add_loss(self._name, loss)
loss.backward(retain_graph=retain_graph)
self.step()

def step(self):
if self._clip_grad != 0:
utils.clip_grad_norm_(self.model.parameters(), self._clip_grad)
self._optimizer.step()
self._optimizer.zero_grad()
self._target.update()

def zero_grad(self):
self._optimizer.zero_grad()

def _enqueue(self, results):
self._cache.append(results)

def _dequeue(self, batch_size):
i = 0
num_items = 0
while num_items < batch_size and i < len(self._cache):
num_items += len(self._cache[i])
i += 1
if num_items != batch_size:
raise ValueError("Incompatible batch size.")
items = torch.cat(self._cache[:i])
self._cache = self._cache[i:]
return items

def _init_target_model(self, target_update_frequency):
if target_update_frequency is not None:
self._target = FixedTarget(target_update_frequency)
self._target.init(self.model)
else:
self._target = TrivialTarget()
self._target.init(self.model)
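
A minimal sketch of the call/eval/reinforce cycle this new base class supports, assuming a toy linear model; everything except the Approximation and FixedTarget signatures shown in this diff is illustrative.

import torch
from torch import nn, optim
from all.approximation import Approximation, FixedTarget

# Toy model and optimizer; any nn.Module/optimizer pair can back an Approximation.
model = nn.Linear(4, 1)
f = Approximation(
    model,
    optim.Adam(model.parameters()),
    target=FixedTarget(1000),  # copy weights into the target every 1000 updates
)

x = torch.randn(32, 4)
values = f(x)        # detached output; the differentiable result is cached internally
targets = f.eval(x)  # evaluated through the target network
f.reinforce(targets) # with the default mse_loss, the cached outputs are regressed
                     # toward this argument, then step() updates model and target
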
42 changes: 18 additions & 24 deletions all/approximation/feature_network.py
@@ -1,48 +1,42 @@
import torch
from torch.nn import utils
from all.environments import State
from .features import Features
from .approximation import Approximation

class FeatureNetwork(Features):
def __init__(self, model, optimizer, clip_grad=0):
self.model = model
self.optimizer = optimizer
self.clip_grad = clip_grad
class FeatureNetwork(Approximation):
def __init__(self, model, optimizer=None, **kwargs):
super().__init__(model, optimizer, **kwargs)
self._cache = []
self._out = []

def __call__(self, states):
features = self.model(states.features.float())
out = features.detach()
out.requires_grad = True
self._cache.append(features)
self._out.append(out)
self._enqueue(features, out)
return State(
out,
mask=states.mask,
info=states.info
)

def eval(self, states):
with torch.no_grad():
training = self.model.training
result = self.model(states.features.float())
self.model.train(training)
return State(
result,
mask=states.mask,
info=states.info
)
result = self._target(states.features.float())
return State(
result,
mask=states.mask,
info=states.info
)

def reinforce(self):
graphs, grads = self._decache()
graphs, grads = self._dequeue()
graphs.backward(grads)
if self.clip_grad != 0:
utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
self.optimizer.step()
self.optimizer.zero_grad()
self.step()

def _enqueue(self, features, out):
self._cache.append(features)
self._out.append(out)

def _decache(self):
def _dequeue(self):
graphs = []
grads = []
for graph, out in zip(self._cache, self._out):
14 changes: 0 additions & 14 deletions all/approximation/features.py

This file was deleted.

22 changes: 22 additions & 0 deletions all/approximation/q_continuous.py
@@ -0,0 +1,22 @@
from torch.nn.functional import mse_loss
from all.nn import QModuleContinuous, td_loss
from .approximation import Approximation

class QContinuous(Approximation):
def __init__(
self,
model,
optimizer,
loss=mse_loss,
name='q',
**kwargs
):
model = QModuleContinuous(model)
loss = td_loss(loss)
super().__init__(
model,
optimizer,
loss=loss,
name=name,
**kwargs
)
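
For context, a hedged sketch of how this continuous-action critic might be constructed for the DDPG agent added above; the network shape (state features concatenated with a one-dimensional action) is an assumption, only the QContinuous and FixedTarget signatures come from this diff.

from torch import nn, optim
from all.approximation import QContinuous, FixedTarget

# Hypothetical critic body; QModuleContinuous wraps it inside QContinuous.
model = nn.Sequential(
    nn.Linear(4 + 1, 64),  # assumed: 4 state features plus a 1-dimensional action
    nn.ReLU(),
    nn.Linear(64, 1),
)
q = QContinuous(
    model,
    optim.Adam(model.parameters(), lr=1e-3),
    target=FixedTarget(1000),
)
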
14 changes: 0 additions & 14 deletions all/approximation/q_function.py

This file was deleted.

77 changes: 14 additions & 63 deletions all/approximation/q_network.py
@@ -1,72 +1,23 @@
import copy
import torch
from torch import optim
from torch.nn.functional import mse_loss
from all.experiments import DummyWriter
from all.layers import ListNetwork
from .q_function import QFunction
from all.nn import QModule, td_loss
from .approximation import Approximation


class QNetwork(QFunction):
class QNetwork(Approximation):
def __init__(
self,
model,
optimizer,
actions,
num_actions,
loss=mse_loss,
target_update_frequency=None,
writer=DummyWriter()
name='q',
**kwargs
):
self.model = ListNetwork(model, (actions,))
self.optimizer = (optimizer
if optimizer is not None
else optim.Adam(model.parameters()))
self.loss = loss
self.cache = None
self.updates = 0
self.target_update_frequency = target_update_frequency
self.target_model = (
copy.deepcopy(self.model)
if target_update_frequency is not None
else self.model
)
self.device = next(model.parameters()).device
self.writer = writer

def __call__(self, states, actions=None):
result = self._eval(states, actions, self.model)
if result.requires_grad:
self.cache = result
return result.detach()

def eval(self, states, actions=None):
with torch.no_grad():
training = self.target_model.training
result = self._eval(states, actions, self.target_model.eval())
self.target_model.train(training)
return result

def reinforce(self, td_errors):
targets = td_errors + self.cache.detach()
loss = self.loss(self.cache, targets)
self.writer.add_loss('q', loss)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
self.updates += 1
if self.should_update_target():
self.target_model.load_state_dict(self.model.state_dict())

def _eval(self, states, actions, model):
values = model(states)
if actions is None:
return values
if isinstance(actions, list):
actions = torch.tensor(actions, device=self.device)
return values.gather(1, actions.view(-1, 1)).squeeze(1)

def should_update_target(self):
return (
(self.target_update_frequency is not None)
and (self.updates % self.target_update_frequency == 0)
model = QModule(model, num_actions)
loss = td_loss(loss)
super().__init__(
model,
optimizer,
loss=loss,
name=name,
**kwargs
)
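
And a similarly hedged sketch of the slimmed-down QNetwork; the network body is an assumption, and the target keyword appears to replace the old target_update_frequency argument, with only the new constructor signature taken from this diff.

from torch import nn, optim
from all.approximation import QNetwork, FixedTarget

# Hypothetical discrete-action Q head; QModule wraps it inside QNetwork.
model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
q = QNetwork(
    model,
    optim.Adam(model.parameters(), lr=1e-3),
    2,                         # num_actions
    target=FixedTarget(1000),  # takes over the role of target_update_frequency
)
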