
Commit 3d7cfca

Author: Jan Michelfeit (committed)

#625 introduce ReplayBufferAwareRewardFn

1 parent d1aae17 commit 3d7cfca


5 files changed: +35 -19 lines


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 13 additions & 4 deletions
@@ -1,15 +1,20 @@
+from typing import Tuple
+
 import numpy as np
 import torch as th
 from gym.vector.utils import spaces
 from stable_baselines3.common.preprocessing import get_obs_shape
 
-from imitation.policies.replay_buffer_wrapper import ReplayBufferView
-from imitation.rewards.reward_function import RewardFn
+from imitation.policies.replay_buffer_wrapper import (
+    ReplayBufferView,
+    ReplayBufferRewardWrapper,
+)
+from imitation.rewards.reward_function import ReplayBufferAwareRewardFn
 from imitation.util import util
 from imitation.util.networks import RunningNorm
 
 
-class StateEntropyReward(RewardFn):
+class StateEntropyReward(ReplayBufferAwareRewardFn):
     def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
         self.nearest_neighbor_k = nearest_neighbor_k
         # TODO support n_envs > 1
@@ -20,8 +25,12 @@ def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
             np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
         )
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferView):
+    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
+        self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)
+
+    def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
         self.replay_buffer_view = replay_buffer
+        self.obs_shape = obs_shape
 
     def __call__(
         self,
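
For orientation, a minimal usage sketch of the updated StateEntropyReward wiring; it is not part of the commit, and the space, buffer size, and k values are placeholders. It shows the new two-argument set_replay_buffer call, which now receives the observation shape alongside the ReplayBufferView.

# Illustrative sketch only (not part of this commit); constants are placeholders.
import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
obs_shape = get_obs_shape(obs_space)
# Buffer of shape (buffer_size, n_envs, *obs_shape), filled elsewhere.
all_observations = np.zeros((1000, 1, *obs_shape))

reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=obs_space)
# New signature: the observation shape is passed explicitly with the buffer view.
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)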

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 6 additions & 7 deletions
@@ -7,7 +7,7 @@
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.type_aliases import ReplayBufferSamples
 
-from imitation.rewards.reward_function import RewardFn
+from imitation.rewards.reward_function import RewardFn, ReplayBufferAwareRewardFn
 from imitation.util import util
 
 
@@ -37,13 +37,13 @@ def __init__(
         observations_buffer: np.ndarray,
         buffer_slice_provider: Callable[[], slice],
     ):
-        self._observations_buffer = observations_buffer.view()
-        self._observations_buffer.flags.writeable = False
+        self._observations_buffer_view = observations_buffer.view()
+        self._observations_buffer_view.flags.writeable = False
         self._buffer_slice_provider = buffer_slice_provider
 
     @property
     def observations(self):
-        return self._observations_buffer[self._buffer_slice_provider()]
+        return self._observations_buffer_view[self._buffer_slice_provider()]
 
 
 class ReplayBufferRewardWrapper(ReplayBuffer):
@@ -57,7 +57,6 @@ def __init__(
         *,
         replay_buffer_class: Type[ReplayBuffer],
         reward_fn: RewardFn,
-        on_initialized_callback: Callable[["ReplayBufferRewardWrapper"], None] = None,
         **kwargs,
     ):
         """Builds ReplayBufferRewardWrapper.
@@ -88,8 +87,8 @@ def __init__(
         self.reward_fn = reward_fn
         _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
         super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)
-        if on_initialized_callback is not None:
-            on_initialized_callback(self)
+        if isinstance(reward_fn, ReplayBufferAwareRewardFn):
+            reward_fn.on_replay_buffer_initialized(self)
 
     # TODO(juan) remove the type ignore once the merged PR
     # https://github.com/python/mypy/pull/13475
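
To make the new control flow concrete, here is a hedged sketch of how a reward function now gets attached, with placeholder spaces and sizes: the explicit on_initialized_callback argument is gone, and the wrapper instead checks isinstance(reward_fn, ReplayBufferAwareRewardFn) after calling super().__init__ and, if it matches, hands itself to reward_fn.on_replay_buffer_initialized.

# Illustrative sketch only; values are placeholders, mirroring the tests further below.
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=obs_space)

buffer = ReplayBufferRewardWrapper(
    1000,  # buffer_size
    obs_space,
    spaces.Discrete(2),  # placeholder action space
    replay_buffer_class=ReplayBuffer,
    reward_fn=reward_fn,  # ReplayBufferAware, so the hook fires during __init__
    n_envs=1,
    handle_timeout_termination=False,
)
# At this point the wrapper has already called
# reward_fn.on_replay_buffer_initialized(buffer); no callback argument is needed.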

src/imitation/rewards/reward_function.py

Lines changed: 6 additions & 0 deletions
@@ -32,3 +32,9 @@ def __call__(
         Returns:
             Computed rewards of shape `(batch_size,`).
         """  # noqa: DAR202
+
+
+class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
+    @abc.abstractmethod
+    def on_replay_buffer_initialized(self, replay_buffer: "ReplayBufferRewardWrapper"):
+        pass
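
A minimal sketch of what a custom implementation of the new interface could look like; the class name and the toy reward logic are invented for illustration, and only the abstract hook comes from this commit.

# Hypothetical subclass, for illustration only.
import numpy as np

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn


class BufferSizeReward(ReplayBufferAwareRewardFn):
    """Toy reward: pays out the number of currently buffered observations."""

    def __init__(self):
        self.replay_buffer_view = None

    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
        # Invoked by ReplayBufferRewardWrapper.__init__ once the buffer exists.
        self.replay_buffer_view = replay_buffer.buffer_view

    def __call__(self, state, action, next_state, done) -> np.ndarray:
        # Toy logic: reward every transition by the current buffer occupancy.
        n = len(self.replay_buffer_view.observations)
        return np.full(len(state), float(n), dtype=np.float32)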

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 5 additions & 3 deletions
@@ -23,8 +23,9 @@ def test_state_entropy_reward_returns_entropy(rng):
     obs_shape = get_obs_shape(SPACE)
     all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
 
+
     reward_fn = StateEntropyReward(K, SPACE)
-    reward_fn.set_buffer_view(ReplayBufferView(all_observations, lambda: slice(None)))
+    reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)
 
     # Act
     observations = rng.random((BATCH_SIZE, *obs_shape))
@@ -48,7 +49,8 @@ def test_state_entropy_reward_returns_normalized_values():
     reward_fn = StateEntropyReward(K, SPACE)
     all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
     reward_fn.set_replay_buffer(
-        ReplayBufferView(all_observations, lambda: slice(None))
+        ReplayBufferView(all_observations, lambda: slice(None)),
+        get_obs_shape(SPACE)
     )
 
     dim = 8
@@ -79,7 +81,7 @@ def test_state_entropy_reward_can_pickle():
 
     obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
     reward_fn = StateEntropyReward(K, SPACE)
-    reward_fn.set_replay_buffer(replay_buffer)
+    reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
     # Act

tests/policies/test_replay_buffer_wrapper.py

Lines changed: 5 additions & 5 deletions
@@ -17,6 +17,7 @@
 from stable_baselines3.common.save_util import load_from_pkl
 
 from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
+from imitation.rewards.reward_function import ReplayBufferAwareRewardFn
 from imitation.util import util
 
 
@@ -175,16 +176,15 @@ def test_replay_buffer_view_provides_buffered_observations():
     np.testing.assert_allclose(view.observations, expected)
 
 
-def test_replay_buffer_reward_wrapper_calls_initialization_callback_with_itself():
-    callback = Mock()
+def test_replay_buffer_reward_wrapper_calls_reward_initialization_callback():
+    reward_fn = Mock(spec=ReplayBufferAwareRewardFn)
     buffer = ReplayBufferRewardWrapper(
         10,
         spaces.Discrete(2),
         spaces.Discrete(2),
         replay_buffer_class=ReplayBuffer,
-        reward_fn=Mock(),
+        reward_fn=reward_fn,
         n_envs=2,
         handle_timeout_termination=False,
-        on_initialized_callback=callback,
     )
-    assert callback.call_args.args[0] is buffer
+    assert reward_fn.on_replay_buffer_initialized.call_args.args[0] is buffer
