Merge pull request #76 from cpnota/release/0.2.1
Release/0.2.1
Showing 49 changed files with 1,000 additions and 360 deletions.
@@ -0,0 +1 @@
import all.nn
@@ -0,0 +1,63 @@
import torch
from .abstract import Agent

class DDPG(Agent):
    def __init__(self,
                 q,
                 policy,
                 replay_buffer,
                 discount_factor=0.99,
                 minibatch_size=32,
                 replay_start_size=5000,
                 update_frequency=1
                 ):
        # objects
        self.q = q
        self.policy = policy
        self.replay_buffer = replay_buffer
        # hyperparameters
        self.replay_start_size = replay_start_size
        self.update_frequency = update_frequency
        self.minibatch_size = minibatch_size
        self.discount_factor = discount_factor
        # data
        self.env = None
        self.state = None
        self.action = None
        self.frames_seen = 0

    def act(self, state, reward):
        self._store_transition(state, reward)
        self._train()
        self.state = state
        self.action = self.policy(state)
        return self.action

    def _store_transition(self, state, reward):
        if self.state and not self.state.done:
            self.frames_seen += 1
            self.replay_buffer.store(self.state, self.action, reward, state)

    def _train(self):
        if self._should_train():
            (states, actions, rewards, next_states, weights) = self.replay_buffer.sample(
                self.minibatch_size)

            # train q function
            td_errors = (
                rewards +
                self.discount_factor * self.q.eval(next_states, self.policy.eval(next_states)) -
                self.q(states, torch.cat(actions))
            )
            self.q.reinforce(weights * td_errors)
            self.replay_buffer.update_priorities(td_errors)

            # train policy
            loss = -self.q(states, self.policy.greedy(states), detach=False).mean()
            loss.backward()
            self.policy.step()
            self.q.zero_grad()

    def _should_train(self):
        return (self.frames_seen > self.replay_start_size and
                self.frames_seen % self.update_frequency == 0)
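
For context, here is a minimal sketch of how this agent might be wired up and driven, not part of the diff. Every name (q, policy, replay_buffer, env) is a placeholder for objects built elsewhere; only the interface the agent actually calls above (eval/reinforce, greedy/step, store/sample/update_priorities) is assumed, and the environment loop is schematic rather than the library's real runner.

# Hypothetical wiring of the DDPG agent above; all component names are illustrative.
agent = DDPG(
    q,                          # continuous Q approximation, e.g. a QContinuous
    policy,                     # deterministic policy approximation
    replay_buffer,              # (prioritized) replay buffer
    discount_factor=0.99,
    minibatch_size=32,
    replay_start_size=5000,
    update_frequency=1,
)

# Schematic interaction loop (assumed env API; the real experiment runner may differ):
state = env.reset()
reward = 0
for _ in range(10000):
    action = agent.act(state, reward)   # stores the transition, maybe trains, then acts
    state, reward = env.step(action)
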
@@ -1,15 +1,6 @@
from .q_function import QFunction
from .approximation import Approximation
from .q_continuous import QContinuous
from .q_network import QNetwork
from .v_function import ValueFunction
from .v_network import ValueNetwork
from .features import Features
from .v_network import VNetwork
from .feature_network import FeatureNetwork

__all__ = [
    "QFunction",
    "QNetwork",
    "ValueFunction",
    "ValueNetwork",
    "Features",
    "FeatureNetwork"
]
from .target import TargetNetwork, FixedTarget, PolyakTarget, TrivialTarget
@@ -0,0 +1,82 @@
import torch
from torch.nn import utils
from torch.nn.functional import mse_loss
from all.experiments import DummyWriter
from .target import FixedTarget, TrivialTarget

class Approximation():
    def __init__(
            self,
            model,
            optimizer,
            clip_grad=0,
            loss_scaling=1,
            loss=mse_loss,
            name='approximation',
            target=None,
            writer=DummyWriter(),
    ):
        self.model = model
        self.device = next(model.parameters()).device
        self._target = target or TrivialTarget()
        self._target.init(model)
        self._updates = 0
        self._optimizer = optimizer
        self._loss = loss
        self._loss_scaling = loss_scaling
        self._cache = []
        self._clip_grad = clip_grad
        self._writer = writer
        self._name = name

    def __call__(self, *inputs, detach=True):
        result = self.model(*inputs)
        if detach:
            self._enqueue(result)
            return result.detach()
        return result

    def eval(self, *inputs):
        return self._target(*inputs)

    def reinforce(self, errors, retain_graph=False):
        batch_size = len(errors)
        cache = self._dequeue(batch_size)
        if cache.requires_grad:
            loss = self._loss(cache, errors) * self._loss_scaling
            self._writer.add_loss(self._name, loss)
            loss.backward(retain_graph=retain_graph)
            self.step()

    def step(self):
        if self._clip_grad != 0:
            utils.clip_grad_norm_(self.model.parameters(), self._clip_grad)
        self._optimizer.step()
        self._optimizer.zero_grad()
        self._target.update()

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _enqueue(self, results):
        self._cache.append(results)

    def _dequeue(self, batch_size):
        i = 0
        num_items = 0
        while num_items < batch_size and i < len(self._cache):
            num_items += len(self._cache[i])
            i += 1
        if num_items != batch_size:
            raise ValueError("Incompatible batch size.")
        items = torch.cat(self._cache[:i])
        self._cache = self._cache[i:]
        return items

    def _init_target_model(self, target_update_frequency):
        if target_update_frequency is not None:
            self._target = FixedTarget(target_update_frequency)
            self._target.init(self.model)
        else:
            self._target = TrivialTarget()
            self._target.init(self.model)
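
A minimal sketch of the call/reinforce cycle this class implements, assuming a tiny PyTorch regression model; it is an illustration rather than code from this commit. Note that with the default mse_loss, the argument to reinforce acts as the regression target for the outputs cached by the preceding call.

import torch
from torch import nn, optim

# Illustrative only: wrap a small model in the Approximation defined above.
model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters())
f = Approximation(model, optimizer, clip_grad=1.0, name='toy')

x = torch.randn(32, 4)
y = torch.randn(32, 1)

prediction = f(x)   # returns a detached output; the graph is cached internally
f.reinforce(y)      # with the default mse_loss, y is the target for the cached
                    # prediction; step() then clips, updates, and syncs the target
estimate = f.eval(x)  # evaluates through the target network (TrivialTarget by
                      # default, which presumably just uses the online model)
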
This file was deleted.
@@ -0,0 +1,22 @@
from torch.nn.functional import mse_loss
from all.nn import QModuleContinuous, td_loss
from .approximation import Approximation

class QContinuous(Approximation):
    def __init__(
            self,
            model,
            optimizer,
            loss=mse_loss,
            name='q',
            **kwargs
    ):
        model = QModuleContinuous(model)
        loss = td_loss(loss)
        super().__init__(
            model,
            optimizer,
            loss=loss,
            name=name,
            **kwargs
        )
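
Read together with the DDPG agent earlier in this diff, the intended call pattern appears to be q(states, actions) for the online estimate and q.eval(states, actions) for the target-network estimate. A hedged usage sketch follows; model, optimizer, states, actions, and the FixedTarget period stand in for values defined elsewhere, and the exact input layout QModuleContinuous expects from the wrapped model is an assumption, not something this hunk shows.

# Assumed usage, mirroring the calls the DDPG agent above makes on its q function.
q = QContinuous(model, optimizer, target=FixedTarget(1000), name='q')

values = q(states, actions)                   # online Q(s, a); graph cached for reinforce
targets = rewards + 0.99 * q.eval(next_states, next_actions)  # target-network estimate
q.reinforce(targets - values)                 # td_loss presumably regresses the cached
                                              # values toward values.detach() + td_errors
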
This file was deleted.
@@ -1,72 +1,23 @@
import copy
import torch
from torch import optim
from torch.nn.functional import mse_loss
from all.experiments import DummyWriter
from all.layers import ListNetwork
from .q_function import QFunction
from all.nn import QModule, td_loss
from .approximation import Approximation

class QNetwork(QFunction):
class QNetwork(Approximation):
    def __init__(
            self,
            model,
            optimizer,
            actions,
            num_actions,
            loss=mse_loss,
            target_update_frequency=None,
            writer=DummyWriter()
            name='q',
            **kwargs
    ):
        self.model = ListNetwork(model, (actions,))
        self.optimizer = (optimizer
                          if optimizer is not None
                          else optim.Adam(model.parameters()))
        self.loss = loss
        self.cache = None
        self.updates = 0
        self.target_update_frequency = target_update_frequency
        self.target_model = (
            copy.deepcopy(self.model)
            if target_update_frequency is not None
            else self.model
        )
        self.device = next(model.parameters()).device
        self.writer = writer

    def __call__(self, states, actions=None):
        result = self._eval(states, actions, self.model)
        if result.requires_grad:
            self.cache = result
        return result.detach()

    def eval(self, states, actions=None):
        with torch.no_grad():
            training = self.target_model.training
            result = self._eval(states, actions, self.target_model.eval())
            self.target_model.train(training)
            return result

    def reinforce(self, td_errors):
        targets = td_errors + self.cache.detach()
        loss = self.loss(self.cache, targets)
        self.writer.add_loss('q', loss)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.updates += 1
        if self.should_update_target():
            self.target_model.load_state_dict(self.model.state_dict())

    def _eval(self, states, actions, model):
        values = model(states)
        if actions is None:
            return values
        if isinstance(actions, list):
            actions = torch.tensor(actions, device=self.device)
        return values.gather(1, actions.view(-1, 1)).squeeze(1)

    def should_update_target(self):
        return (
            (self.target_update_frequency is not None)
            and (self.updates % self.target_update_frequency == 0)
        model = QModule(model, num_actions)
        loss = td_loss(loss)
        super().__init__(
            model,
            optimizer,
            loss=loss,
            name=name,
            **kwargs
        )
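
The explicit TD-error handling in the removed reinforce method (targets = td_errors + cache.detach(), then loss(cache, targets)) is now delegated to td_loss from all.nn, whose implementation is not shown in this hunk. The sketch below is a hypothetical reconstruction of such a wrapper, inferred from the deleted code rather than taken from the library.

# Hypothetical reconstruction of td_loss, inferred from the old QNetwork.reinforce;
# the real all.nn.td_loss may differ.
def td_loss(loss):
    def _loss(estimates, td_errors):
        # Regress the cached estimates toward estimate + td_error,
        # exactly as the removed reinforce() did explicitly.
        return loss(estimates, estimates.detach() + td_errors)
    return _loss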