Commit 15c682a

Author: Jan Michelfeit (committed)
#625 fix entropy shape
1 parent ddd7b2f commit 15c682a


3 files changed, +7 −7 lines changed


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 3 additions & 2 deletions
@@ -94,7 +94,8 @@ def _entropy_reward(self, state, action, next_state, done):
         all_observations = all_observations.reshape((-1, *self.obs_shape))

         if all_observations.shape[0] < self.nearest_neighbor_k:
-            # not enough observations to compare to, fall back to the learned function
+            # not enough observations to compare to, fall back to the learned function;
+            # (falling back to a constant may also be ok)
             return self.learned_reward_fn(state, action, next_state, done)
         else:
             # TODO #625: deal with the conversion back and forth between np and torch
@@ -104,7 +105,7 @@ def _entropy_reward(self, state, action, next_state, done):
                 self.nearest_neighbor_k,
             )
             normalized_entropies = self.entropy_stats.forward(entropies)
-            return normalized_entropies.numpy()
+            return normalized_entropies.numpy()

     def __getstate__(self):
         state = self.__dict__.copy()
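
For readers following along, the branch edited above behaves roughly as follows. This is a minimal, self-contained sketch rather than the repository's code: only the names visible in the diff (obs_shape, nearest_neighbor_k, learned_reward_fn, entropy_stats) are reused, and the k-NN entropy computation is inlined with plain torch calls.

# Minimal sketch (not the repository's exact code) of the logic around the
# comment change above; everything not named in the diff is illustrative.
import torch as th

def entropy_reward_sketch(
    state, action, next_state, done,
    all_observations, obs_shape, nearest_neighbor_k,
    learned_reward_fn, entropy_stats,
):
    all_observations = all_observations.reshape((-1, *obs_shape))
    if all_observations.shape[0] < nearest_neighbor_k:
        # Not enough observations to compare to: fall back to the learned function.
        return learned_reward_fn(state, action, next_state, done)
    # Distance from each query state to every stored observation; the k-th
    # nearest-neighbour distance serves as the state-entropy proxy.
    obs = th.as_tensor(state, dtype=th.float32).reshape(len(state), -1)
    bank = th.as_tensor(all_observations, dtype=th.float32).reshape(len(all_observations), -1)
    distances = th.cdist(obs, bank)
    # k + 1 skips the zero distance a point has to itself when it is also in the bank.
    entropies = th.kthvalue(distances, k=nearest_neighbor_k + 1, dim=1).values
    # entropies has shape (batch,); normalize and hand back a numpy array.
    return entropy_stats.forward(entropies).numpy()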

src/imitation/util/util.py

Lines changed: 1 addition & 2 deletions
@@ -389,5 +389,4 @@ def compute_state_entropy(
     # a point is itself, which we want to skip.
     assert distances_tensor.shape[-1] > k
     knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
-    state_entropy = knn_dists
-    return state_entropy.unsqueeze(1)
+    return knn_dists
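
The shape fix here relies on th.kthvalue already reducing the chosen dimension: on an (N, M) distance matrix it returns an (N,)-shaped tensor, so the old unsqueeze(1) was appending a spurious trailing dimension. A toy check with made-up sizes:

import torch as th

distances_tensor = th.rand(5, 8)  # 5 query points, 8 candidate distances each
k = 3
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
assert knn_dists.shape == (5,)                 # shape returned after this commit
assert knn_dists.unsqueeze(1).shape == (5, 1)  # shape returned before the fix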

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 3 additions & 3 deletions
@@ -21,20 +21,20 @@


 def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
-    all_observations = rng.random((BUFFER_SIZE, VENVS, *(OBS_SHAPE)))
+    all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))

     reward_fn = PebbleStateEntropyReward(Mock(), K)
     reward_fn.set_replay_buffer(
         ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
     )

     # Act
-    observations = th.rand((BATCH_SIZE, *(OBS_SHAPE)))
+    observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
     reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

     # Assert
     expected = util.compute_state_entropy(
-        observations, all_observations.reshape(-1, *(OBS_SHAPE)), K
+        observations, all_observations.reshape(-1, *OBS_SHAPE), K
     )
     expected_normalized = reward_fn.entropy_stats.normalize(
         th.as_tensor(expected)
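
With the util.py change above, the expected entropies in this test come out flat, one value per batch element. A shape-only sanity check, assuming the compute_state_entropy helper from this branch is importable and using toy constants in place of the test module's:

import torch as th
from imitation.util import util  # assumes the branch with compute_state_entropy

BATCH_SIZE, OBS_SHAPE, K = 4, (3,), 2   # toy stand-ins, not the test's constants
observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
all_observations = th.rand((10, *OBS_SHAPE))
expected = util.compute_state_entropy(observations, all_observations, K)
assert expected.shape == (BATCH_SIZE,)  # flat entropies after this commit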
