
Commit a3369d4

Author: Jan Michelfeit (committed)

#641 code review: refactor PebbleStateEntropyReward so that inner RewardNets are initialized in constructor

1 parent 50577b0 commit a3369d4

4 files changed: +43 -40 lines

src/imitation/algorithms/pebble/entropy_reward.py
Lines changed: 33 additions & 33 deletions

```diff
@@ -18,42 +18,42 @@
 from imitation.util.networks import RunningNorm
 
 
-class PebbleRewardPhase(enum.Enum):
-    """States representing different behaviors for PebbleStateEntropyReward."""
-
-    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
-    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
-
-
 class InsufficientObservations(RuntimeError):
     pass
 
 
-class EntropyRewardNet(RewardNet):
+class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
     def __init__(
         self,
         nearest_neighbor_k: int,
-        replay_buffer_view: ReplayBufferView,
         observation_space: gym.Space,
         action_space: gym.Space,
         normalize_images: bool = True,
+        replay_buffer_view: Optional[ReplayBufferView] = None,
     ):
         """Initialize the RewardNet.
 
         Args:
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
             observation_space: the observation space of the environment
             action_space: the action space of the environment
             normalize_images: whether to automatically normalize
                 image observations to [0, 1] (from 0 to 255). Defaults to True.
+            replay_buffer_view: Replay buffer view with observations to compare
+                against when computing entropy. If None is given, the buffer needs to
+                be set with on_replay_buffer_initialized() before EntropyRewardNet can
+                be used
         """
         super().__init__(observation_space, action_space, normalize_images)
         self.nearest_neighbor_k = nearest_neighbor_k
         self._replay_buffer_view = replay_buffer_view
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
-        """This method needs to be called after unpickling.
+    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
+        """Sets replay buffer.
 
-        See also __getstate__() / __setstate__()
+        This method needs to be called, e.g., after unpickling.
+        See also __getstate__() / __setstate__().
         """
         assert self.observation_space == replay_buffer.observation_space
         assert self.action_space == replay_buffer.action_space
```
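
Because replay_buffer_view now defaults to None, an EntropyRewardNet (and the NormalizedRewardNet that wraps it) can be built before any replay buffer exists, which is what lets PebbleStateEntropyReward construct these nets in its own __init__ below. A minimal sketch of that wiring, assuming an illustrative Box space; the space, the shapes, and the NormalizedRewardNet import path are assumptions rather than part of this diff:

```python
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import EntropyRewardNet
from imitation.rewards.reward_nets import NormalizedRewardNet  # assumed import path
from imitation.util.networks import RunningNorm

SPACE = Box(-1.0, 1.0, shape=(4,))  # illustrative observation/action space

# No replay buffer yet: the new default leaves _replay_buffer_view as None.
entropy_net = EntropyRewardNet(
    nearest_neighbor_k=5,
    observation_space=SPACE,
    action_space=SPACE,
    normalize_images=False,
)

# The same wrapping that the refactored PebbleStateEntropyReward.__init__ performs.
normalized_net = NormalizedRewardNet(entropy_net, RunningNorm)

# Later, once a ReplayBufferRewardWrapper exists, attach it; the method asserts
# that the wrapper's observation/action spaces match the net's.
# entropy_net.on_replay_buffer_initialized(replay_buffer_wrapper)
```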
```diff
@@ -111,6 +111,13 @@ def __setstate__(self, state):
         self._replay_buffer_view = None
 
 
+class PebbleRewardPhase(enum.Enum):
+    """States representing different behaviors for PebbleStateEntropyReward."""
+
+    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
+
+
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """Reward function for implementation of the PEBBLE learning algorithm.
 
@@ -126,14 +133,15 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     reward is returned.
 
     The second phase requires that a buffer with observations to compare against is
-    supplied with set_replay_buffer() or on_replay_buffer_initialized().
-    To transition to the last phase, unsupervised_exploration_finish() needs
-    to be called.
+    supplied with on_replay_buffer_initialized(). To transition to the last phase,
+    unsupervised_exploration_finish() needs to be called.
     """
 
     def __init__(
         self,
         learned_reward_fn: RewardFn,
+        observation_space: gym.Space,
+        action_space: gym.Space,
         nearest_neighbor_k: int = 5,
     ):
         """Builds this class.
@@ -146,28 +154,20 @@ def __init__(
         """
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
-
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
-        # These two need to be set with set_replay_buffer():
-        self._entropy_reward_net: Optional[EntropyRewardNet] = None
-        self._normalized_entropy_reward_net: Optional[RewardNet] = None
+        self._entropy_reward_net = EntropyRewardNet(
+            nearest_neighbor_k=self.nearest_neighbor_k,
+            observation_space=observation_space,
+            action_space=action_space,
+            normalize_images=False,
+        )
+        self._normalized_entropy_reward_net = NormalizedRewardNet(
+            self._entropy_reward_net, RunningNorm
+        )
 
     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        if self._normalized_entropy_reward_net is None:
-            self._entropy_reward_net = EntropyRewardNet(
-                nearest_neighbor_k=self.nearest_neighbor_k,
-                replay_buffer_view=replay_buffer.buffer_view,
-                observation_space=replay_buffer.observation_space,
-                action_space=replay_buffer.action_space,
-                normalize_images=False,
-            )
-            self._normalized_entropy_reward_net = NormalizedRewardNet(
-                self._entropy_reward_net, RunningNorm
-            )
-        else:
-            assert self._entropy_reward_net is not None
-            self._entropy_reward_net.set_replay_buffer(replay_buffer)
+        self._entropy_reward_net.on_replay_buffer_initialized(replay_buffer)
 
     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
```
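
With the inner nets created in the constructor, the lifecycle of PebbleStateEntropyReward reduces to: construct with spaces, attach the replay buffer via on_replay_buffer_initialized(), query entropy-based rewards during unsupervised exploration, and call unsupervised_exploration_finish() to switch to the learned reward. A rough usage sketch mirroring the updated tests; the mocked buffer wrapper, the attributes set on it, and the array shapes are illustrative assumptions:

```python
import numpy as np
from gym.spaces import Box
from unittest.mock import Mock

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

SPACE = Box(-1.0, 1.0, shape=(4,))
N_ENVS, BUFFER_SIZE = 8, 128  # illustrative sizes

# Inner reward nets are built right here, not on the first buffer callback.
reward_fn = PebbleStateEntropyReward(
    learned_reward_fn=Mock(return_value=np.ones(N_ENVS)),  # stand-in RewardFn
    observation_space=SPACE,
    action_space=SPACE,
    nearest_neighbor_k=5,
)

# Attach the replay buffer; a Mock stands in for ReplayBufferRewardWrapper,
# exposing the attributes the entropy net is assumed to check and read.
buffer_wrapper = Mock()
buffer_wrapper.observation_space = SPACE
buffer_wrapper.action_space = SPACE
buffer_wrapper.buffer_view = ReplayBufferView(
    np.random.rand(BUFFER_SIZE, N_ENVS, *SPACE.shape), lambda: slice(None)
)
reward_fn.on_replay_buffer_initialized(buffer_wrapper)

# Unsupervised exploration phase: entropy-based (normalized) reward.
obs = np.random.rand(N_ENVS, *SPACE.shape)
pretrain_reward = reward_fn(obs, obs, obs, obs)

# After this transition the wrapped learned_reward_fn is used instead.
reward_fn.unsupervised_exploration_finish()
learned_reward = reward_fn(obs, obs, obs, obs)
```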

src/imitation/scripts/train_preference_comparisons.py
Lines changed: 3 additions & 1 deletion

```diff
@@ -74,7 +74,9 @@ def make_reward_function(
     if pebble_enabled:
         relabel_reward_fn = PebbleStateEntropyReward(
             relabel_reward_fn,  # type: ignore[assignment]
-            pebble_nearest_neighbor_k,
+            observation_space=reward_net.observation_space,
+            action_space=reward_net.action_space,
+            nearest_neighbor_k=pebble_nearest_neighbor_k,
         )
     return relabel_reward_fn
 
```
tests/algorithms/pebble/test_entropy_reward.py
Lines changed: 6 additions & 5 deletions

```diff
@@ -5,8 +5,9 @@
 
 import numpy as np
 import torch as th
-from gym.spaces import Discrete, Box
+from gym.spaces import Box
 from gym.spaces.space import Space
+
 from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.util import util
@@ -23,7 +24,7 @@
 def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
     all_observations = rng.random((BUFFER_SIZE, VENVS) + SPACE.shape)
 
-    reward_fn = PebbleStateEntropyReward(Mock(), K)
+    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, SPACE, K)
     reward_fn.on_replay_buffer_initialized(
         replay_buffer_mock(
             ReplayBufferView(all_observations, lambda: slice(None)),
@@ -50,7 +51,7 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
     # only stats collection in this test
     m.side_effect = lambda obs, all_obs, k: obs
 
-    reward_fn = PebbleStateEntropyReward(Mock(), K)
+    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, SPACE, K)
     all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape))
     reward_fn.on_replay_buffer_initialized(
         replay_buffer_mock(
@@ -88,7 +89,7 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
     expected_reward = np.ones(1)
     learned_reward_mock = Mock()
     learned_reward_mock.return_value = expected_reward
-    reward_fn = PebbleStateEntropyReward(learned_reward_mock)
+    reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE, SPACE)
     # move all the way to the last state
     reward_fn.unsupervised_exploration_finish()
 
@@ -111,7 +112,7 @@ def test_pebble_entropy_reward_can_pickle():
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))
 
     obs1 = np.random.rand(VENVS, *SPACE.shape)
-    reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
+    reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, SPACE, K)
     reward_fn.on_replay_buffer_initialized(replay_buffer_mock(replay_buffer, SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
```
tests/algorithms/test_preference_comparisons.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -85,7 +85,7 @@ def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
     replay_buffer_mock = Mock()
     replay_buffer_mock.buffer_view = replay_buffer
     replay_buffer_mock.obs_shape = (4,)
-    reward_fn = PebbleStateEntropyReward(reward_net.predict_processed)
+    reward_fn = PebbleStateEntropyReward(reward_net.predict_processed, venv.observation_space, venv.action_space)
     reward_fn.on_replay_buffer_initialized(replay_buffer_mock)
     return preference_comparisons.PebbleAgentTrainer(
         algorithm=agent,
```
