Skip to content

Commit ec7b853

Browse files
author
Jan Michelfeit
committed
#625 fix entropy_reward.py
1 parent ad29c34 commit ec7b853

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ def __call__(
3535

3636
all_observations = self.replay_buffer_view.observations
3737
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
38-
all_observations = all_observations.reshape((-1, *self.obs_shape))
38+
all_observations = all_observations.reshape((-1, *state.shape[1:])) # TODO #625: fix self.obs_shape
39+
# TODO #625: deal with the conversion back and forth between np and torch
3940
entropies = util.compute_state_entropy(
40-
state,
41-
all_observations,
41+
th.tensor(state),
42+
th.tensor(all_observations),
4243
self.nearest_neighbor_k,
4344
)
44-
normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
45+
normalized_entropies = self.entropy_stats.forward(entropies)
4546
return normalized_entropies.numpy()
4647

4748
def __getstate__(self):

0 commit comments

Comments
 (0)