Skip to content

Commit 27b8a55

Browse files
author
Jan Michelfeit
committed
#625 fix assumptions about shapes in ReplayBufferEntropyRewardWrapper
1 parent 0a435bc commit 27b8a55

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -167,26 +167,22 @@ def sample(self, *args, **kwargs):
167167
all_obs = self.observations
168168
else:
169169
all_obs = self.observations[: self.pos]
170+
# super().sample() flattens the venv dimension, let's do it too
171+
all_obs = all_obs.reshape((-1, *self.obs_shape))
170172
entropies = util.compute_state_entropy(
171-
# TODO support multiple environments
172-
samples.observations.unsqueeze(1),
173-
all_obs,
173+
samples.observations,
174+
all_obs.reshape((-1, *self.obs_shape)),
174175
self.k,
175176
)
176177

177178
# Normalize to have mean of 0 and standard deviation of 1 according to running stats
178179
entropies = self.entropy_stats.forward(entropies)
179-
180-
entropies_th = (
181-
util.safe_to_tensor(entropies)
182-
.reshape(samples.rewards.shape)
183-
.to(samples.rewards.device)
184-
)
180+
assert entropies.shape == samples.rewards.shape
185181

186182
return ReplayBufferSamples(
187-
samples.observations,
188-
samples.actions,
189-
samples.next_observations,
190-
samples.dones,
191-
entropies_th,
183+
observations=samples.observations,
184+
actions=samples.actions,
185+
next_observations=samples.next_observations,
186+
dones=samples.dones,
187+
rewards=entropies,
192188
)

src/imitation/scripts/config/train_preference_comparisons_pebble.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def train_sac():
5050

5151

5252
@common.common_ingredient.config
53-
def mountain_car():
53+
def common_mountain_car_continuous():
5454
env_name = "MountainCarContinuous-v0"
5555
locals() # quieten flake8
5656

0 commit comments

Comments
 (0)