Merge pull request #80 from cpnota/release/0.2.2
Release/0.2.2
Showing 54 changed files with 1,342 additions and 487 deletions.
`all/agents/__init__.py`:

```diff
@@ -1,17 +1,21 @@
-from .abstract import Agent
+from ._agent import Agent
 from .a2c import A2C
-from .actor_critic import ActorCritic
 from .ddpg import DDPG
 from .dqn import DQN
-from .sarsa import Sarsa
+from .ppo import PPO
+from .vac import VAC
 from .vpg import VPG
+from .vqn import VQN
+from .vsarsa import VSarsa

 __all__ = [
     "Agent",
     "A2C",
-    "ActorCritic",
     "DDPG",
     "DQN",
-    "Sarsa",
+    "PPO",
+    "VAC",
     "VPG",
+    "VQN",
+    "VSarsa",
 ]
```
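With this change, the renamed and newly added agents are all importable from the package root. A quick illustrative usage (the import path follows directly from this `__init__.py`):

```python
# The agents added or renamed in this release, imported from the package root:
from all.agents import PPO, VAC, VQN, VSarsa
```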
File renamed without changes.
This file was deleted.
`all/agents/ddpg.py`:

```diff
@@ -1,5 +1,5 @@
 import torch
-from .abstract import Agent
+from ._agent import Agent

 class DDPG(Agent):
     def __init__(self,
```
`all/agents/dqn.py`:

```diff
@@ -1,5 +1,5 @@
 import torch
-from .abstract import Agent
+from ._agent import Agent


 class DQN(Agent):
```
`all/agents/ppo.py` (new file, 92 lines):

```python
import torch
from all.memory import GeneralizedAdvantageBuffer
from ._agent import Agent


class PPO(Agent):
    def __init__(
            self,
            features,
            v,
            policy,
            epsilon=0.2,
            epochs=4,
            minibatches=4,
            n_envs=None,
            n_steps=4,
            discount_factor=0.99,
            lam=0.95
    ):
        if n_envs is None:
            raise RuntimeError("Must specify n_envs.")
        self.features = features
        self.v = v
        self.policy = policy
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.discount_factor = discount_factor
        self.lam = lam
        self._epsilon = epsilon
        self._epochs = epochs
        self._batch_size = n_envs * n_steps
        self._minibatches = minibatches
        self._buffer = self._make_buffer()
        self._features = []

    def act(self, states, rewards):
        self._train()
        actions = self.policy.eval(self.features.eval(states))
        self._buffer.store(states, actions, rewards)
        return actions

    def _train(self):
        if len(self._buffer) >= self._batch_size:
            states, actions, advantages = self._buffer.sample(self._batch_size)
            with torch.no_grad():
                features = self.features.eval(states)
                pi_0 = self.policy.eval(features, actions)
                targets = self.v.eval(features) + advantages
            for _ in range(self._epochs):
                self._train_epoch(states, actions, pi_0, advantages, targets)

    def _train_epoch(self, states, actions, pi_0, advantages, targets):
        minibatch_size = int(self._batch_size / self._minibatches)
        indexes = torch.randperm(self._batch_size)
        for n in range(self._minibatches):
            first = n * minibatch_size
            last = first + minibatch_size
            i = indexes[first:last]
            self._train_minibatch(states[i], actions[i], pi_0[i], advantages[i], targets[i])

    def _train_minibatch(self, states, actions, pi_0, advantages, targets):
        features = self.features(states)
        self.policy(features, actions)
        self.policy.reinforce(self._compute_policy_loss(pi_0, advantages))
        self.v.reinforce(targets - self.v(features))
        self.features.reinforce()

    def _compute_targets(self, returns, next_states, lengths):
        return (
            returns +
            (self.discount_factor ** lengths)
            * self.v.eval(self.features.eval(next_states))
        )

    def _compute_policy_loss(self, pi_0, advantages):
        def _policy_loss(pi_i):
            ratios = torch.exp(pi_i - pi_0)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1.0 - self._epsilon, 1.0 + self._epsilon) * advantages
            return -torch.min(surr1, surr2).mean()
        return _policy_loss

    def _make_buffer(self):
        return GeneralizedAdvantageBuffer(
            self.v,
            self.features,
            self.n_steps,
            self.n_envs,
            discount_factor=self.discount_factor,
            lam=self.lam
        )
```
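`_compute_policy_loss` above is the PPO clipped surrogate objective: with probability ratio r = exp(log pi_new - log pi_old), it minimizes -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]. A minimal standalone sketch of the same computation in plain PyTorch (the function name and signature here are illustrative, not part of the library):

```python
import torch

def clipped_surrogate_loss(log_pi_new, log_pi_old, advantages, epsilon=0.2):
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space
    # for numerical stability.
    ratios = torch.exp(log_pi_new - log_pi_old)
    # Unclipped and clipped surrogate terms.
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # PPO maximizes the pessimistic minimum of the two, so the loss is its
    # negation; clipping removes the incentive to move the policy too far
    # from the policy that collected the data.
    return -torch.min(surr1, surr2).mean()
```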
`all/agents/vac.py` (new file, 23 lines):

```python
from ._agent import Agent

class VAC(Agent):
    '''Vanilla Actor-Critic'''
    def __init__(self, features, v, policy, gamma=1):
        self.features = features
        self.v = v
        self.policy = policy
        self.gamma = gamma
        self._previous_features = None

    def act(self, state, reward):
        if self._previous_features:
            td_error = (
                reward
                + self.gamma * self.v.eval(self.features.eval(state))
                - self.v(self._previous_features)
            )
            self.v.reinforce(td_error)
            self.policy.reinforce(td_error)
            self.features.reinforce()
        self._previous_features = self.features(state)
        return self.policy(self._previous_features)
```
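The update in `act` is the classic one-step actor-critic rule: the TD error delta = r + gamma * V(s') - V(s) serves both as the critic's error signal and as the weight on the policy gradient. A sketch of just the TD-error arithmetic (the names here are illustrative, not the library's API):

```python
# One-step TD error for a state-value critic:
#   delta = r + gamma * V(s') - V(s)
# e.g. with r = 1.0, gamma = 0.99, V(s') = 2.0, V(s) = 2.5:
#   delta = 1.0 + 0.99 * 2.0 - 2.5 = 0.48
def td_error(reward, gamma, v_next, v_current):
    return reward + gamma * v_next - v_current
```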
`all/agents/vqn.py` (new file, 26 lines):

```python
import torch
from ._agent import Agent


class VQN(Agent):
    '''Vanilla Q-Network'''
    def __init__(self, q, policy, gamma=1):
        self.q = q
        self.policy = policy
        self.gamma = gamma
        self.env = None
        self.previous_state = None
        self.previous_action = None

    def act(self, state, reward):
        action = self.policy(state)
        if self.previous_state:
            td_error = (
                reward
                + self.gamma * torch.max(self.q.eval(state), dim=1)[0]
                - self.q(self.previous_state, self.previous_action)
            )
            self.q.reinforce(td_error)
        self.previous_state = state
        self.previous_action = action
        return action
```
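VQN differs from VAC by bootstrapping from an action-value function: its target uses the greedy value of the next state, y = r + gamma * max_a Q(s', a), which is the Q-learning update. A small sketch, assuming a batch of Q-values of shape (batch, num_actions) (function and parameter names are illustrative):

```python
import torch

def q_learning_td_error(reward, gamma, next_q_values, current_q):
    # Bootstrapped greedy target: y = r + gamma * max_a Q(s', a).
    target = reward + gamma * torch.max(next_q_values, dim=1).values
    # TD error used to move Q(s, a) toward the target.
    return target - current_q
```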
`all/memory/__init__.py`:

```diff
@@ -1,10 +1,14 @@
 from .replay_buffer import ReplayBuffer, ExperienceReplayBuffer, PrioritizedReplayBuffer
 from .n_step import NStepBuffer, NStepBatchBuffer
+from .advantage import NStepAdvantageBuffer
+from .generalized_advantage import GeneralizedAdvantageBuffer

 __all__ = [
     "ReplayBuffer",
     "ExperienceReplayBuffer",
     "PrioritizedReplayBuffer",
     "NStepBuffer",
-    "NStepBatchBuffer"
+    "NStepBatchBuffer",
+    "NStepAdvantageBuffer",
+    "GeneralizedAdvantageBuffer",
 ]
```
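The new `GeneralizedAdvantageBuffer` (consumed by the PPO agent above) is named for generalized advantage estimation. A minimal sketch of the GAE recursion, assuming a single rollout with value estimates for each state and its successor; this is illustrative, not the buffer's actual implementation:

```python
import torch

def generalized_advantages(rewards, values, next_values,
                           discount_factor=0.99, lam=0.95):
    # GAE recursion over a rollout, computed backwards in time:
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * A_{t+1}
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + discount_factor * next_values[t] - values[t]
        gae = delta + discount_factor * lam * gae
        advantages[t] = gae
    return advantages
```

For simplicity this sketch ignores episode boundaries; a real buffer must reset the recursion at terminal states.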