Commit ddd7b2f
Author: Jan Michelfeit

#625 entropy_reward can automatically detect if enough observations are present

1 parent 88371e1 commit ddd7b2f

File tree

2 files changed: 53 additions & 73 deletions


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 31 additions & 31 deletions
@@ -16,27 +16,28 @@
 class PebbleRewardPhase(Enum):
     """States representing different behaviors for PebbleStateEntropyReward"""

-    # Collecting samples so that we have something for entropy calculation
-    LEARNING_START = auto()
-    # Entropy based reward
-    UNSUPERVISED_EXPLORATION = auto()
-    # Learned reward
-    POLICY_AND_REWARD_LEARNING = auto()
+    UNSUPERVISED_EXPLORATION = auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = auto()  # Learned reward


 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """
     Reward function for implementation of the PEBBLE learning algorithm
     (https://arxiv.org/pdf/2106.05091.pdf).

-    The rewards returned by this function go through the three phases
-    defined in PebbleRewardPhase. To transition between these phases,
-    unsupervised_exploration_start() and unsupervised_exploration_finish()
-    need to be called.
+    The rewards returned by this function go through the three phases:
+    1. Before enough samples are collected for entropy calculation, the
+       underlying function is returned. This shouldn't matter because
+       OffPolicyAlgorithms have an initialization period for `learning_starts`
+       timesteps.
+    2. During the unsupervised exploration phase, entropy based reward is returned
+    3. After unsupervised exploration phase is finished, the underlying learned
+       reward is returned.

-    The second phase (UNSUPERVISED_EXPLORATION) also requires that a buffer
-    with observations to compare against is supplied with set_replay_buffer()
-    or on_replay_buffer_initialized().
+    The second phase requires that a buffer with observations to compare against is
+    supplied with set_replay_buffer() or on_replay_buffer_initialized().
+    To transition to the last phase, unsupervised_exploration_finish() needs
+    to be called.

     Args:
         learned_reward_fn: The learned reward function used after unsupervised
@@ -51,11 +52,10 @@ def __init__(
         learned_reward_fn: RewardFn,
         nearest_neighbor_k: int = 5,
     ):
-        self.trained_reward_fn = learned_reward_fn
+        self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
-        # TODO support n_envs > 1
         self.entropy_stats = RunningNorm(1)
-        self.state = PebbleRewardPhase.LEARNING_START
+        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

         # These two need to be set with set_replay_buffer():
         self.replay_buffer_view = None
@@ -68,10 +68,6 @@ def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
         self.replay_buffer_view = replay_buffer
         self.obs_shape = obs_shape

-    def unsupervised_exploration_start(self):
-        assert self.state == PebbleRewardPhase.LEARNING_START
-        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
-
     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
         self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING
@@ -84,26 +80,30 @@ def __call__(
         done: np.ndarray,
     ) -> np.ndarray:
         if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
-            return self._entropy_reward(state)
+            return self._entropy_reward(state, action, next_state, done)
         else:
-            return self.trained_reward_fn(state, action, next_state, done)
+            return self.learned_reward_fn(state, action, next_state, done)

-    def _entropy_reward(self, state):
+    def _entropy_reward(self, state, action, next_state, done):
         if self.replay_buffer_view is None:
             raise ValueError(
                 "Replay buffer must be supplied before entropy reward can be used"
             )
-
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
         all_observations = all_observations.reshape((-1, *self.obs_shape))
-        # TODO #625: deal with the conversion back and forth between np and torch
-        entropies = util.compute_state_entropy(
-            th.tensor(state),
-            th.tensor(all_observations),
-            self.nearest_neighbor_k,
-        )
-        normalized_entropies = self.entropy_stats.forward(entropies)
+
+        if all_observations.shape[0] < self.nearest_neighbor_k:
+            # not enough observations to compare to, fall back to the learned function
+            return self.learned_reward_fn(state, action, next_state, done)
+        else:
+            # TODO #625: deal with the conversion back and forth between np and torch
+            entropies = util.compute_state_entropy(
+                th.tensor(state),
+                th.tensor(all_observations),
+                self.nearest_neighbor_k,
+            )
+            normalized_entropies = self.entropy_stats.forward(entropies)
         return normalized_entropies.numpy()

     def __getstate__(self):

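For orientation, a minimal usage sketch of the reworked phase lifecycle of PebbleStateEntropyReward. It is illustrative only: the ReplayBufferView import path, the dummy learned reward function, and the placeholder buffer contents are assumptions, not part of this commit.

    import numpy as np

    from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
    # assumed import path for ReplayBufferView; not shown in this diff
    from imitation.policies.replay_buffer_wrapper import ReplayBufferView

    obs_shape = (4,)

    def dummy_learned_reward(state, action, next_state, done):
        # placeholder for a learned reward function (RewardFn)
        return np.zeros(len(state))

    reward_fn = PebbleStateEntropyReward(dummy_learned_reward, nearest_neighbor_k=5)
    batch = np.ones((2, *obs_shape))

    # Phase 1: the buffer holds fewer than nearest_neighbor_k observations,
    # so the call falls back to the learned reward function.
    few_obs = np.zeros((3, 1, *obs_shape))
    reward_fn.set_replay_buffer(ReplayBufferView(few_obs, lambda: slice(None)), obs_shape)
    r1 = reward_fn(batch, batch, batch, np.zeros(2))

    # Phase 2: with enough observations, a normalized state-entropy reward is returned.
    many_obs = np.random.rand(100, 1, *obs_shape)
    reward_fn.set_replay_buffer(ReplayBufferView(many_obs, lambda: slice(None)), obs_shape)
    r2 = reward_fn(batch, batch, batch, np.zeros(2))

    # Phase 3: after unsupervised exploration ends, the learned reward is used again.
    reward_fn.unsupervised_exploration_finish()
    r3 = reward_fn(batch, batch, batch, np.zeros(2))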
tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 22 additions & 42 deletions
@@ -20,51 +20,13 @@
 VENVS = 2


-def test_pebble_entropy_reward_function_returns_learned_reward_initially():
-    expected_reward = np.ones(1)
-    learned_reward_mock = Mock()
-    learned_reward_mock.return_value = expected_reward
-    reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE)
-
-    # Act
-    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
-    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
-
-    # Assert
-    assert reward == expected_reward
-    learned_reward_mock.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
-    )
-
-
-def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training():
-    expected_reward = np.ones(1)
-    learned_reward_mock = Mock()
-    learned_reward_mock.return_value = expected_reward
-    reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE)
-    # move all the way to the last state
-    reward_fn.unsupervised_exploration_start()
-    reward_fn.unsupervised_exploration_finish()
-
-    # Act
-    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
-    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
-
-    # Assert
-    assert reward == expected_reward
-    learned_reward_mock.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
-    )
-
-
 def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
     all_observations = rng.random((BUFFER_SIZE, VENVS, *(OBS_SHAPE)))

-    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
+    reward_fn = PebbleStateEntropyReward(Mock(), K)
     reward_fn.set_replay_buffer(
         ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
     )
-    reward_fn.unsupervised_exploration_start()

     # Act
     observations = th.rand((BATCH_SIZE, *(OBS_SHAPE)))
@@ -85,13 +47,12 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
     # mock entropy computation so that we can test only stats collection in this test
     m.side_effect = lambda obs, all_obs, k: obs

-    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
+    reward_fn = PebbleStateEntropyReward(Mock(), K)
     all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
     reward_fn.set_replay_buffer(
         ReplayBufferView(all_observations, lambda: slice(None)),
         OBS_SHAPE,
     )
-    reward_fn.unsupervised_exploration_start()

     dim = 8
     shift = 3
@@ -115,12 +76,31 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
     )


+def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training():
+    expected_reward = np.ones(1)
+    learned_reward_mock = Mock()
+    learned_reward_mock.return_value = expected_reward
+    reward_fn = PebbleStateEntropyReward(learned_reward_mock)
+    # move all the way to the last state
+    reward_fn.unsupervised_exploration_finish()
+
+    # Act
+    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
+    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+    # Assert
+    assert reward == expected_reward
+    learned_reward_mock.assert_called_once_with(
+        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+    )
+
+
 def test_pebble_entropy_reward_can_pickle():
     all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

     obs1 = np.random.rand(VENVS, *OBS_SHAPE)
-    reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, K)
+    reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
     reward_fn.set_replay_buffer(replay_buffer, OBS_SHAPE)
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

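A possible follow-up test for the new fallback path (fewer observations in the buffer than nearest_neighbor_k) could mirror the conventions of the existing tests. The test name and the K - 1 buffer size below are illustrative; the constants (K, OBS_SHAPE, BATCH_SIZE, PLACEHOLDER) are assumed to be the ones already defined in this file.

    def test_pebble_entropy_reward_falls_back_to_learned_reward_with_few_observations():
        # hypothetical test sketch, not part of this commit
        expected_reward = np.ones(1)
        learned_reward_mock = Mock()
        learned_reward_mock.return_value = expected_reward
        reward_fn = PebbleStateEntropyReward(learned_reward_mock, K)

        # a buffer with fewer than K observations (flattened across venvs)
        too_few_observations = np.empty((K - 1, 1, *OBS_SHAPE))
        reward_fn.set_replay_buffer(
            ReplayBufferView(too_few_observations, lambda: slice(None)),
            OBS_SHAPE,
        )

        # Act
        observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
        reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

        # Assert
        assert reward == expected_reward
        learned_reward_mock.assert_called_once_with(
            observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
        )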