Merge pull request #80 from cpnota/release/0.2.2
Release/0.2.2
cpnota authored Jul 20, 2019
2 parents 341e525 + dfef2a5 commit 4d63716
Showing 54 changed files with 1,342 additions and 487 deletions.
14 changes: 9 additions & 5 deletions all/agents/__init__.py
@@ -1,17 +1,21 @@
from .abstract import Agent
from ._agent import Agent
from .a2c import A2C
from .actor_critic import ActorCritic
from .ddpg import DDPG
from .dqn import DQN
from .sarsa import Sarsa
from .ppo import PPO
from .vac import VAC
from .vpg import VPG
from .vqn import VQN
from .vsarsa import VSarsa

__all__ = [
"Agent",
"A2C",
"ActorCritic",
"DDPG",
"DQN",
"Sarsa",
"PPO",
"VAC",
"VPG",
"VQN",
"VSarsa",
]
all/agents/abstract.py → all/agents/_agent.py
File renamed without changes.
21 changes: 9 additions & 12 deletions all/agents/a2c.py
@@ -1,7 +1,7 @@
import torch
from all.environments import State
from all.memory import NStepBatchBuffer
from .abstract import Agent
from all.memory import NStepAdvantageBuffer
from ._agent import Agent


class A2C(Agent):
@@ -36,20 +36,17 @@ def act(self, states, rewards):
def _train(self):
if len(self._buffer) >= self._batch_size:
states = State.from_list(self._features)
_, _, returns, next_states, rollout_lengths = self._buffer.sample(self._batch_size)
td_errors = (
returns
+ (self.discount_factor ** rollout_lengths)
* self.v.eval(self.features.eval(next_states))
- self.v(states)
)
self.v.reinforce(td_errors)
self.policy.reinforce(td_errors)
_, _, advantages = self._buffer.sample(self._batch_size)
self.v(states)
self.v.reinforce(advantages)
self.policy.reinforce(advantages)
self.features.reinforce()
self._features = []

def _make_buffer(self):
return NStepBatchBuffer(
return NStepAdvantageBuffer(
self.v,
self.features,
self.n_steps,
self.n_envs,
discount_factor=self.discount_factor
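Note on the buffer change above: NStepAdvantageBuffer hands the agent advantages directly, rather than returns, next states, and rollout lengths from which the agent previously computed TD errors itself. A minimal sketch of an n-step advantage computation for a single environment, using a hypothetical helper (not the library's buffer internals):

    import torch

    def n_step_advantages(rewards, values, next_value, discount_factor=0.99):
        # rewards: (n_steps,) rewards r_t .. r_{t+n-1}
        # values: (n_steps,) critic estimates V(s_t) .. V(s_{t+n-1})
        # next_value: bootstrap estimate V(s_{t+n}) as a 0-d tensor
        returns = torch.zeros_like(rewards)
        running = next_value
        for t in reversed(range(len(rewards))):
            running = rewards[t] + discount_factor * running
            returns[t] = running
        # advantage = bootstrapped n-step return minus the baseline value estimate
        return returns - values

For example, n_step_advantages(torch.tensor([1., 0., 1.]), torch.tensor([0.6, 0.4, 0.7]), torch.tensor(0.5)) returns one advantage per stored step.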
20 changes: 0 additions & 20 deletions all/agents/actor_critic.py

This file was deleted.

2 changes: 1 addition & 1 deletion all/agents/ddpg.py
@@ -1,5 +1,5 @@
import torch
from .abstract import Agent
from ._agent import Agent

class DDPG(Agent):
def __init__(self,
2 changes: 1 addition & 1 deletion all/agents/dqn.py
@@ -1,5 +1,5 @@
import torch
from .abstract import Agent
from ._agent import Agent


class DQN(Agent):
92 changes: 92 additions & 0 deletions all/agents/ppo.py
@@ -0,0 +1,92 @@
import torch
from all.memory import GeneralizedAdvantageBuffer
from ._agent import Agent


class PPO(Agent):
def __init__(
self,
features,
v,
policy,
epsilon=0.2,
epochs=4,
minibatches=4,
n_envs=None,
n_steps=4,
discount_factor=0.99,
lam=0.95
):
if n_envs is None:
raise RuntimeError("Must specify n_envs.")
self.features = features
self.v = v
self.policy = policy
self.n_envs = n_envs
self.n_steps = n_steps
self.discount_factor = discount_factor
self.lam = lam
self._epsilon = epsilon
self._epochs = epochs
self._batch_size = n_envs * n_steps
self._minibatches = minibatches
self._buffer = self._make_buffer()
self._features = []

def act(self, states, rewards):
self._train()
actions = self.policy.eval(self.features.eval(states))
self._buffer.store(states, actions, rewards)
return actions

def _train(self):
if len(self._buffer) >= self._batch_size:
states, actions, advantages = self._buffer.sample(self._batch_size)
with torch.no_grad():
features = self.features.eval(states)
pi_0 = self.policy.eval(features, actions)
targets = self.v.eval(features) + advantages
for _ in range(self._epochs):
self._train_epoch(states, actions, pi_0, advantages, targets)

def _train_epoch(self, states, actions, pi_0, advantages, targets):
minibatch_size = int(self._batch_size / self._minibatches)
indexes = torch.randperm(self._batch_size)
for n in range(self._minibatches):
first = n * minibatch_size
last = first + minibatch_size
i = indexes[first:last]
self._train_minibatch(states[i], actions[i], pi_0[i], advantages[i], targets[i])

def _train_minibatch(self, states, actions, pi_0, advantages, targets):
features = self.features(states)
self.policy(features, actions)
self.policy.reinforce(self._compute_policy_loss(pi_0, advantages))
self.v.reinforce(targets - self.v(features))
self.features.reinforce()

def _compute_targets(self, returns, next_states, lengths):
return (
returns +
(self.discount_factor ** lengths)
* self.v.eval(self.features.eval(next_states))
)

def _compute_policy_loss(self, pi_0, advantages):
def _policy_loss(pi_i):
ratios = torch.exp(pi_i - pi_0)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1.0 - self._epsilon, 1.0 + self._epsilon) * advantages
return -torch.min(surr1, surr2).mean()
return _policy_loss

def _make_buffer(self):
return GeneralizedAdvantageBuffer(
self.v,
self.features,
self.n_steps,
self.n_envs,
discount_factor=self.discount_factor,
lam=self.lam
)
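The policy loss in _compute_policy_loss above is the standard PPO clipped surrogate objective, with pi_0 and pi_i holding log-probabilities. A self-contained restatement on stand-in tensors (the names here are illustrative, not part of the library):

    import torch

    def clipped_surrogate_loss(log_pi_new, log_pi_old, advantages, epsilon=0.2):
        # probability ratio pi_new(a|s) / pi_old(a|s), computed in log space
        ratios = torch.exp(log_pi_new - log_pi_old)
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
        # pessimistic bound of the two surrogates, negated for gradient descent
        return -torch.min(surr1, surr2).mean()

Clipping the ratio to [1 - epsilon, 1 + epsilon] keeps each minibatch update from moving the policy too far from the one that collected the data, which is why pi_0 is captured under torch.no_grad() before the epoch loop.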

23 changes: 23 additions & 0 deletions all/agents/vac.py
@@ -0,0 +1,23 @@
from ._agent import Agent

class VAC(Agent):
'''Vanilla Actor-Critic'''
def __init__(self, features, v, policy, gamma=1):
self.features = features
self.v = v
self.policy = policy
self.gamma = gamma
self._previous_features = None

def act(self, state, reward):
if self._previous_features:
td_error = (
reward
+ self.gamma * self.v.eval(self.features.eval(state))
- self.v(self._previous_features)
)
self.v.reinforce(td_error)
self.policy.reinforce(td_error)
self.features.reinforce()
self._previous_features = self.features(state)
return self.policy(self._previous_features)
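VAC's update hinges on the one-step TD error delta = reward + gamma * V(next_state) - V(previous_state), which serves as both the critic's error signal and the policy-gradient weight. A tiny numeric sketch with made-up values:

    import torch

    gamma = 0.99
    reward = torch.tensor(1.0)    # hypothetical reward
    v_next = torch.tensor(0.5)    # critic estimate for the new state
    v_prev = torch.tensor(0.8)    # critic estimate for the previous state

    td_error = reward + gamma * v_next - v_prev   # 1.0 + 0.495 - 0.8 = 0.695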
14 changes: 10 additions & 4 deletions all/agents/vpg.py
@@ -1,21 +1,25 @@
import torch
from all.environments import State
from .abstract import Agent
from ._agent import Agent

class VPG(Agent):
'''Vanilla Policy Gradient'''
def __init__(
self,
features,
v,
policy,
gamma=0.99,
n_episodes=1
# run complete episodes until we have
# seen at least min_batch_size states
min_batch_size=1
):
self.features = features
self.v = v
self.policy = policy
self.gamma = gamma
self.n_episodes = n_episodes
self.min_batch_size = min_batch_size
self._current_batch_size = 0
self._trajectories = []
self._features = []
self._rewards = []
@@ -43,10 +47,11 @@ def _terminal(self, reward):
features = torch.cat(self._features)
rewards = torch.tensor(self._rewards, device=features.device)
self._trajectories.append((features, rewards))
self._current_batch_size += len(features)
self._features = []
self._rewards = []

if len(self._trajectories) >= self.n_episodes:
if self._current_batch_size >= self.min_batch_size:
self._train()

def _train(self):
@@ -59,6 +64,7 @@ def _train(self):
self.policy.reinforce(advantages)
self.features.reinforce()
self._trajectories = []
self._current_batch_size = 0

def _compute_advantages(self, features, rewards):
returns = self._compute_discounted_returns(rewards)
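For context on the batching change: VPG now accumulates whole episodes until at least min_batch_size states have been seen, then trains on the pooled trajectories. The per-episode returns fed into _compute_advantages follow the usual discounted sum; a minimal sketch, assuming a 1-D reward tensor for one episode (this helper is illustrative, not the library's code):

    import torch

    def discounted_returns(rewards, gamma=0.99):
        # rewards: (episode_length,) rewards for a single complete episode
        returns = torch.zeros_like(rewards)
        running = 0.
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns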
26 changes: 26 additions & 0 deletions all/agents/vqn.py
@@ -0,0 +1,26 @@
import torch
from ._agent import Agent


class VQN(Agent):
'''Vanilla Q-Network'''
def __init__(self, q, policy, gamma=1):
self.q = q
self.policy = policy
self.gamma = gamma
self.env = None
self.previous_state = None
self.previous_action = None

def act(self, state, reward):
action = self.policy(state)
if self.previous_state:
td_error = (
reward
+ self.gamma * torch.max(self.q.eval(state), dim=1)[0]
- self.q(self.previous_state, self.previous_action)
)
self.q.reinforce(td_error)
self.previous_state = state
self.previous_action = action
return action
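The VQN update above is one-step Q-learning: the bootstrap target uses the greedy action value max_a Q(s', a) via torch.max(..., dim=1)[0]. A small sketch with a hypothetical batch of one state and three actions:

    import torch

    gamma = 0.99
    reward = torch.tensor([1.0])
    q_next = torch.tensor([[0.2, 0.7, 0.1]])   # Q(s', a) for each of 3 actions
    q_prev = torch.tensor([0.4])               # Q(s, a) for the action taken

    # greedy bootstrap over the action dimension
    td_error = reward + gamma * torch.max(q_next, dim=1)[0] - q_prev
    # 1.0 + 0.99 * 0.7 - 0.4 = 1.293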
5 changes: 3 additions & 2 deletions all/agents/sarsa.py → all/agents/vsarsa.py
@@ -1,7 +1,8 @@
from .abstract import Agent
from ._agent import Agent


class Sarsa(Agent):
class VSarsa(Agent):
'''Vanilla SARSA'''
def __init__(self, q, policy, gamma=1):
self.q = q
self.policy = policy
2 changes: 1 addition & 1 deletion all/environments/gym.py
@@ -121,5 +121,5 @@ def _convert(self, action):
if isinstance(self.action_space, gym.spaces.Discrete):
return action.item()
if isinstance(self.action_space, gym.spaces.Box):
return action.cpu().detach().numpy()[0]
return action.cpu().detach().numpy().reshape(-1)
raise TypeError("Unknown action space type")
6 changes: 6 additions & 0 deletions all/environments/state.py
@@ -65,6 +65,12 @@ def __getitem__(self, idx):
self._mask[idx],
self._info[idx]
)
if isinstance(idx, torch.Tensor):
return State(
self._raw[idx],
self._mask[idx],
# can't copy info
)
return State(
self._raw[idx].unsqueeze(0),
self._mask[idx].unsqueeze(0),
14 changes: 7 additions & 7 deletions all/experiments/experiment_test.py
@@ -1,7 +1,7 @@
import unittest
import numpy as np
import torch
from all.presets.classic_control import sarsa
from all.presets.classic_control import dqn
from all.experiments import Experiment, Writer


@@ -51,24 +51,24 @@ def setUp(self):
self.experiment.env.seed(0)

def test_adds_label(self):
self.experiment.run(sarsa(), console=False)
self.assertEqual(self.experiment._writer.label, "_sarsa")
self.experiment.run(dqn(), console=False)
self.assertEqual(self.experiment._writer.label, "_dqn")

def test_writes_returns_eps(self):
self.experiment.run(sarsa(), console=False)
self.experiment.run(dqn(), console=False)
np.testing.assert_equal(
self.experiment._writer.data["evaluation/returns-by-episode"]["values"],
np.array([9., 12., 10.])
np.array([14., 19., 26.])
)
np.testing.assert_equal(
self.experiment._writer.data["evaluation/returns-by-episode"]["steps"],
np.array([1, 2, 3])
)

def test_writes_loss(self):
self.experiment.run(sarsa(), console=False)
self.experiment.run(dqn(), console=False)
self.assertTrue(self.experiment._writer.write_loss)
self.experiment.run(sarsa(), console=False, write_loss=False)
self.experiment.run(dqn(), console=False, write_loss=False)
self.assertFalse(self.experiment._writer.write_loss)

if __name__ == '__main__':
6 changes: 5 additions & 1 deletion all/memory/__init__.py
@@ -1,10 +1,14 @@
from .replay_buffer import ReplayBuffer, ExperienceReplayBuffer, PrioritizedReplayBuffer
from .n_step import NStepBuffer, NStepBatchBuffer
from .advantage import NStepAdvantageBuffer
from .generalized_advantage import GeneralizedAdvantageBuffer

__all__ = [
"ReplayBuffer",
"ExperienceReplayBuffer",
"PrioritizedReplayBuffer",
"NStepBuffer",
"NStepBatchBuffer"
"NStepBatchBuffer",
"NStepAdvantageBuffer",
"GeneralizedAdvantageBuffer",
]
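GeneralizedAdvantageBuffer backs the new PPO agent. For reference, a minimal sketch of the standard GAE recursion it is named after, for a single environment (illustrative only, not the buffer's implementation):

    import torch

    def generalized_advantages(rewards, values, next_value, discount_factor=0.99, lam=0.95):
        # rewards, values: (n_steps,) tensors; next_value: bootstrap V(s_n) as a 0-d tensor
        advantages = torch.zeros_like(rewards)
        advantage = 0.
        v_next = next_value
        for t in reversed(range(len(rewards))):
            # one-step TD error at time t
            delta = rewards[t] + discount_factor * v_next - values[t]
            # exponentially weighted sum of future TD errors
            advantage = delta + discount_factor * lam * advantage
            advantages[t] = advantage
            v_next = values[t]
        return advantages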