
Building blocks for PEBBLE #625


Open · wants to merge 55 commits into base: master
Changes from 49 commits

Commits (55 total)
8d5900a
Welfords alg and test
dan-pandori Nov 10, 2022
4aac074
Next func
dan-pandori Nov 10, 2022
383fce0
Test update
dan-pandori Nov 10, 2022
055fa67
compute_state_entropy and test
dan-pandori Nov 11, 2022
5c278f4
Sketch of the entropy reward replay buffer
dan-pandori Nov 11, 2022
49dc26f
Batchify state entropy func
dan-pandori Nov 11, 2022
394ad56
Final sketch of replay entropy buffer.
dan-pandori Nov 11, 2022
21da532
First test
dan-pandori Nov 11, 2022
15dad99
Test cleanup
dan-pandori Nov 11, 2022
0c28079
Update
dan-pandori Nov 11, 2022
5ab9d28
Commit for diff
dan-pandori Nov 12, 2022
9410c31
Push final-ish state
dan-pandori Nov 12, 2022
fdcdf0d
#625 refactor RunningMeanAndVar
Nov 29, 2022
0cd1255
#625 use RunningNorm instead of RunningMeanAndVar
Nov 29, 2022
d88ba44
#625 make copy of train_preference_comparisons.py for pebble
Nov 29, 2022
2d836de
#625 use an OffPolicy for pebble
Nov 29, 2022
ec5f67e
#625 fix assumptions about shapes in ReplayBufferEntropyRewardWrapper
Nov 30, 2022
da228bd
#625 entropy reward as a function
Nov 30, 2022
1ec645a
#625 make entropy reward serializable with pickle
Dec 1, 2022
4e16c42
#625 revert change of compute_state_entropy() from tensors to numpy
Dec 1, 2022
acb51be
#625 extract _preference_feedback_schedule()
Dec 1, 2022
8143ba3
#625 introduce parameter for pretraining steps
Dec 1, 2022
184e191
#625 add initialized callback to ReplayBufferRewardWrapper
Dec 1, 2022
52d914a
#625 fix entropy_reward.py
Dec 1, 2022
1f01a7a
#625 remove ReplayBufferEntropyRewardWrapper
Dec 1, 2022
1fbc590
#625 introduce ReplayBufferAwareRewardFn
Dec 1, 2022
e19dd85
#625 rename PebbleStateEntropyReward
Dec 1, 2022
da77f5c
#625 PebbleStateEntropyReward can switch from unsupervised pretraining
Dec 1, 2022
a11e775
#625 add optional pretraining to PreferenceComparisons
Dec 1, 2022
7b12162
#625 PebbleStateEntropyReward supports the initial phase before repla…
Dec 1, 2022
e354e16
#625 entropy_reward can automatically detect if enough observations a…
Dec 1, 2022
b8ccf2f
#625 fix entropy shape
Dec 1, 2022
c5f1dba
#625 rename unsupervised_agent_pretrain_frac parameter
Dec 1, 2022
0ba8959
#625 specialized PebbleAgentTrainer to distinguish from old preferenc…
Dec 1, 2022
c55fee7
#625 merge pebble to train_preference_comparisons.py and configure on…
Dec 1, 2022
1f9642a
#625 plug in pebble according to parameters
Dec 1, 2022
6f05b1d
#625 fix pre-commit errors
Dec 1, 2022
c787877
#625 add test for pebble agent trainer
Dec 1, 2022
b9c5614
#625 fix more pre-commit errors
Dec 1, 2022
40e7387
#625 fix even more pre-commit errors
Dec 2, 2022
aad2e7c
code review - Update src/imitation/policies/replay_buffer_wrapper.py
mifeet Dec 2, 2022
e0aea61
#625 code review
Dec 2, 2022
f0a3359
#625 code review: do not allocate timesteps for pretraining if there …
Dec 2, 2022
8cb2449
Update src/imitation/algorithms/preference_comparisons.py
mifeet Dec 2, 2022
378baa8
#625 code review: remove ignore
Dec 2, 2022
d7ad414
#625 code review - skip pretrainining if zero timesteps
Dec 2, 2022
412550d
#625 code review: separate pebble and environment configuration
Dec 2, 2022
7c3470e
#625 fix even even more pre-commit errors
Dec 2, 2022
73b1e36
#625 fix even even more pre-commit errors
Dec 2, 2022
6daa473
#641 code review: remove set_replay_buffer
Dec 7, 2022
c80fb80
#641 code review: fix comment
Dec 7, 2022
50577b0
#641 code review: replace RunningNorm with NormalizedRewardNet
Dec 10, 2022
531b353
#641 code review: refactor PebbleStateEntropyReward so that inner Rew…
Dec 10, 2022
74ba96b
#641 fix static analysis and tests
Dec 10, 2022
b344cbd
#641 increase coverage
Dec 12, 2022
1 change: 1 addition & 0 deletions src/imitation/algorithms/pebble/__init__.py
@@ -0,0 +1 @@
"""PEBBLE specific algorithms."""
Member:

It feels a bit odd that we have preference_comparisons.py in a single file but PEBBLE (much smaller) split across several files. That's probably a sign we should split up preference_comparisons.py rather than aggregate PEBBLE, though.

Contributor:

We can do that; e.g., the classes for working with fragments and for preference gathering seem like independent pieces of logic. Probably for another PR, though.

125 changes: 125 additions & 0 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -0,0 +1,125 @@
"""Reward function for the PEBBLE training algorithm."""

import enum
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch as th

from imitation.policies.replay_buffer_wrapper import (
ReplayBufferAwareRewardFn,
ReplayBufferRewardWrapper,
ReplayBufferView,
)
from imitation.rewards.reward_function import RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm


class PebbleRewardPhase(enum.Enum):
"""States representing different behaviors for PebbleStateEntropyReward."""

UNSUPERVISED_EXPLORATION = enum.auto() # Entropy based reward
POLICY_AND_REWARD_LEARNING = enum.auto() # Learned reward


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
"""Reward function for implementation of the PEBBLE learning algorithm.

See https://arxiv.org/abs/2106.05091 .

The rewards returned by this function go through three phases:
1. Before enough samples are collected for the entropy calculation, the
reward from the underlying learned function is returned. This shouldn't
matter because OffPolicyAlgorithms have an initialization period of
`learning_starts` timesteps.
2. During the unsupervised exploration phase, an entropy-based reward is
returned.
3. After the unsupervised exploration phase is finished, the underlying learned
reward is returned.

The second phase requires that a buffer with observations to compare against is
supplied with set_replay_buffer() or on_replay_buffer_initialized().
To transition to the last phase, unsupervised_exploration_finish() needs
to be called.
"""

def __init__(
self,
learned_reward_fn: RewardFn,
nearest_neighbor_k: int = 5,
):
"""Builds this class.

Args:
learned_reward_fn: The learned reward function used after unsupervised
exploration is finished.
nearest_neighbor_k: Parameter for entropy computation (see
compute_state_entropy()).
"""
self.learned_reward_fn = learned_reward_fn
self.nearest_neighbor_k = nearest_neighbor_k
self.entropy_stats = RunningNorm(1)
self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

# These two need to be set with set_replay_buffer():
self.replay_buffer_view: Optional[ReplayBufferView] = None
self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None

def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)

def set_replay_buffer(
self,
replay_buffer: ReplayBufferView,
obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
):
self.replay_buffer_view = replay_buffer
self.obs_shape = obs_shape

def unsupervised_exploration_finish(self):
assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING

def __call__(
self,
state: np.ndarray,
action: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> np.ndarray:
if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
return self._entropy_reward(state, action, next_state, done)
else:
return self.learned_reward_fn(state, action, next_state, done)

def _entropy_reward(self, state, action, next_state, done):
if self.replay_buffer_view is None:
raise ValueError(
"Replay buffer must be supplied before entropy reward can be used",
)
all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension; adapt to that here
all_observations = all_observations.reshape((-1, *self.obs_shape))

if all_observations.shape[0] < self.nearest_neighbor_k:
# not enough observations to compare to, fall back to the learned function;
# (falling back to a constant may also be ok)
return self.learned_reward_fn(state, action, next_state, done)
else:
# TODO #625: deal with the conversion back and forth between np and torch
entropies = util.compute_state_entropy(
th.tensor(state),
th.tensor(all_observations),
self.nearest_neighbor_k,
)
normalized_entropies = self.entropy_stats.forward(entropies)
return normalized_entropies.numpy()

def __getstate__(self):
state = self.__dict__.copy()
del state["replay_buffer_view"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.replay_buffer_view = None
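
A minimal usage sketch of PebbleStateEntropyReward as added above (illustration only, not part of the diff). It assumes this branch of imitation is installed; zero_learned_reward and DummyBufferView are hypothetical stand-ins, the latter duck-typing only the .observations attribute that _entropy_reward() reads.

import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward


def zero_learned_reward(state, action, next_state, done) -> np.ndarray:
    # Hypothetical stand-in for a learned reward model's predict_processed.
    return np.zeros(state.shape[0], dtype=np.float32)


class DummyBufferView:
    """Hypothetical stand-in for ReplayBufferView; exposes only `.observations`."""

    def __init__(self, observations: np.ndarray):
        self.observations = observations


obs_shape = (3,)
reward_fn = PebbleStateEntropyReward(zero_learned_reward, nearest_neighbor_k=5)
# Buffer observations have shape (buffer_size, n_envs, *obs_shape); the reward
# function reshapes them to (-1, *obs_shape) internally.
buffer = DummyBufferView(np.random.rand(64, 1, *obs_shape).astype(np.float32))
reward_fn.set_replay_buffer(buffer, obs_shape)

batch_obs = np.random.rand(8, *obs_shape).astype(np.float32)
dummy = np.zeros(8, dtype=np.float32)

# Phase 2: entropy-based reward during unsupervised exploration.
entropy_rewards = reward_fn(batch_obs, dummy, batch_obs, dummy)

# Phase 3: after exploration is finished, calls fall through to the learned reward.
reward_fn.unsupervised_exploration_finish()
learned_rewards = reward_fn(batch_obs, dummy, batch_obs, dummy)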
139 changes: 123 additions & 16 deletions src/imitation/algorithms/preference_comparisons.py
@@ -33,6 +33,7 @@
from tqdm.auto import tqdm

from imitation.algorithms import base
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import rollout, types, wrappers
from imitation.data.types import (
AnyPath,
@@ -44,6 +45,7 @@
from imitation.policies import exploration_wrapper
from imitation.regularization import regularizers
from imitation.rewards import reward_function, reward_nets, reward_wrapper
from imitation.rewards.reward_function import RewardFn
from imitation.util import logger as imit_logger
from imitation.util import networks, util

@@ -75,6 +77,40 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
be the environment rewards, not ones from a reward model).
""" # noqa: DAR202

@property
def has_pretraining(self) -> bool:
"""Indicates whether this generator has a pre-training phase.

The value can be used, e.g., when allocating time-steps for pre-training.

By default, True is returned if the unsupervised_pretrain() method is not
overridden, bud subclasses may choose to override this behavior.

Returns:
True if this generator has a pre-training phase, False otherwise
"""
orig_impl = TrajectoryGenerator.unsupervised_pretrain
return type(self).unsupervised_pretrain != orig_impl

def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
"""Pre-train an agent before collecting comparisons.

By default, this method only warns if any pre-training steps are allocated.
Override this behavior in subclasses that implement pre-training.

Args:
steps: number of environment steps to train for.
**kwargs: additional keyword arguments to pass on to
the training procedure.
"""
if steps > 0:
self._logger.warn(
f"{steps} timesteps allocated for unsupervised pre-training:"
" Trajectory generators without pre-training implementation should"
" not consume any timesteps (otherwise the total number of"
" timesteps executed may be misleading)",
)

def train(self, steps: int, **kwargs: Any) -> None:
"""Train an agent if the trajectory generator uses one.

@@ -165,7 +201,7 @@ def __init__(
reward_fn.action_space,
)
reward_fn = reward_fn.predict_processed
self.reward_fn = reward_fn
self.reward_fn: RewardFn = reward_fn
self.exploration_frac = exploration_frac
self.rng = rng

@@ -316,6 +352,43 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None:
self.algorithm.set_logger(self.logger)


class PebbleAgentTrainer(AgentTrainer):
"""Specialization of AgentTrainer for PEBBLE training.

Includes unsupervised pretraining with an entropy-based reward function.
"""

reward_fn: PebbleStateEntropyReward

def __init__(
self,
*,
reward_fn: PebbleStateEntropyReward,
**kwargs,
) -> None:
"""Builds PebbleAgentTrainer.

Args:
reward_fn: Pebble reward function
**kwargs: additional keyword arguments to pass on to the parent class

Raises:
ValueError: Unexpected type of reward_fn given.
"""
if not isinstance(reward_fn, PebbleStateEntropyReward):
raise ValueError(
f"{self.__class__.__name__} expects "
f"{PebbleStateEntropyReward.__name__} reward function",
)
super().__init__(reward_fn=reward_fn, **kwargs)

def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
self.train(steps, **kwargs)
fn = self.reward_fn
assert isinstance(fn, PebbleStateEntropyReward)
fn.unsupervised_exploration_finish()


def _get_trajectories(
trajectories: Sequence[TrajectoryWithRew],
steps: int,
@@ -1495,6 +1568,7 @@ def __init__(
transition_oversampling: float = 1,
initial_comparison_frac: float = 0.1,
initial_epoch_multiplier: float = 200.0,
unsupervised_agent_pretrain_frac: float = 0.05,
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
allow_variable_horizon: bool = False,
rng: Optional[np.random.Generator] = None,
@@ -1544,6 +1618,9 @@ def __init__(
initial_epoch_multiplier: before agent training begins, train the reward
model for this many more epochs than usual (on fragments sampled from a
random agent).
unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
agent will be trained without preference gathering (and reward model
training).
custom_logger: Where to log to; if None (default), creates a new logger.
allow_variable_horizon: If False (default), algorithm will raise an
exception if it detects trajectories of different length during
@@ -1642,6 +1719,7 @@ def __init__(
self.fragment_length = fragment_length
self.initial_comparison_frac = initial_comparison_frac
self.initial_epoch_multiplier = initial_epoch_multiplier
self.unsupervised_agent_pretrain_frac = unsupervised_agent_pretrain_frac
self.num_iterations = num_iterations
self.transition_oversampling = transition_oversampling
if callable(query_schedule):
@@ -1670,25 +1748,31 @@ def train(
A dictionary with final metrics such as loss and accuracy
of the reward model.
"""
initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
total_comparisons -= initial_comparisons

# Compute the number of comparisons to request at each iteration in advance.
vec_schedule = np.vectorize(self.query_schedule)
unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
probs = unnormalized_probs / np.sum(unnormalized_probs)
shares = util.oric(probs * total_comparisons)
schedule = [initial_comparisons] + shares.tolist()
print(f"Query schedule: {schedule}")

timesteps_per_iteration, extra_timesteps = divmod(
total_timesteps,
self.num_iterations,
)
preference_query_schedule = self._preference_gather_schedule(total_comparisons)
self.logger.log(f"Query schedule: {preference_query_schedule}")

(
unsup_pretrain_timesteps,
timesteps_per_iteration,
extra_timesteps,
) = self._compute_timesteps(total_timesteps)
reward_loss = None
reward_accuracy = None

for i, num_pairs in enumerate(schedule):
###################################################
# Pre-training agent before gathering preferences #
###################################################
if unsup_pretrain_timesteps:
with self.logger.accumulate_means("agent"):
self.logger.log(
f"Pre-training agent for {unsup_pretrain_timesteps} timesteps",
)
self.trajectory_generator.unsupervised_pretrain(
unsup_pretrain_timesteps,
)

for i, num_pairs in enumerate(preference_query_schedule):
##########################
# Gather new preferences #
##########################
Expand Down Expand Up @@ -1751,3 +1835,26 @@ def train(
self._iteration += 1

return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}

def _preference_gather_schedule(self, total_comparisons):
initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
total_comparisons -= initial_comparisons
vec_schedule = np.vectorize(self.query_schedule)
unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
probs = unnormalized_probs / np.sum(unnormalized_probs)
shares = util.oric(probs * total_comparisons)
schedule = [initial_comparisons] + shares.tolist()
return schedule

def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
if self.trajectory_generator.has_pretraining:
unsupervised_pretrain_timesteps = int(
total_timesteps * self.unsupervised_agent_pretrain_frac,
)
else:
unsupervised_pretrain_timesteps = 0
timesteps_per_iteration, extra_timesteps = divmod(
total_timesteps - unsupervised_pretrain_timesteps,
self.num_iterations,
)
return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
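
For illustration (not part of the diff), the arithmetic in _compute_timesteps() splits the timestep budget as sketched below for a generator with a pre-training phase; the concrete numbers are arbitrary example values.

# Illustration only: mirrors the arithmetic of _compute_timesteps() above.
total_timesteps = 10_000
unsupervised_agent_pretrain_frac = 0.05
num_iterations = 6

unsup_pretrain_timesteps = int(total_timesteps * unsupervised_agent_pretrain_frac)
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - unsup_pretrain_timesteps,
    num_iterations,
)
# 500 timesteps go to unsupervised pre-training; the remaining 9500 split into
# 6 iterations of 1583 timesteps each, with 2 timesteps left over.
print(unsup_pretrain_timesteps, timesteps_per_iteration, extra_timesteps)  # 500 1583 2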
2 changes: 1 addition & 1 deletion src/imitation/policies/base.py
@@ -76,7 +76,7 @@ class SAC1024Policy(sac_policies.SACPolicy):
"""Actor and value networks with two hidden layers of 1024 units respectively.

This matches the implementation of SAC policies in the PEBBLE paper. See:
https://arxiv.org/pdf/2106.05091.pdf
https://arxiv.org/abs/2106.05091
https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml

Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units