Commit d1aae17

Author: Jan Michelfeit (committed)
#625 remove ReplayBufferEntropyRewardWrapper
1 parent ec7b853 commit d1aae17

File tree: 5 files changed, +11 / -182 lines


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,9 @@ def __call__(
 
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
-        all_observations = all_observations.reshape((-1, *state.shape[1:]))  # TODO #625: fix self.obs_shape
+        all_observations = all_observations.reshape(
+            (-1, *state.shape[1:])  # TODO #625: fix self.obs_shape
+        )
         # TODO #625: deal with the conversion back and forth between np and torch
         entropies = util.compute_state_entropy(
             th.tensor(state),
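For context, the reshape in this hunk only collapses the venv dimension of the buffered observations so they line up with the already-flattened sampled batch. A minimal sketch with hypothetical shapes (names and sizes are illustrative, not taken from the code):

import numpy as np

# Hypothetical shapes: a buffer of 100 steps from 4 parallel envs, 3-dim observations.
all_observations = np.zeros((100, 4, 3))
state = np.zeros((32, 3))  # sampled batch; its venv dimension is already flattened

# Collapse the venv dimension so the buffer matches the batch layout,
# as the reshape in the hunk above does.
flat = all_observations.reshape((-1, *state.shape[1:]))
assert flat.shape == (400, 3)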

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 1 addition & 83 deletions
@@ -1,7 +1,6 @@
 """Wrapper for reward labeling for transitions sampled from a replay buffer."""
 
-from typing import Callable
-from typing import Mapping, Type
+from typing import Callable, Mapping, Type
 
 import numpy as np
 from gym import spaces
@@ -10,7 +9,6 @@
 
 from imitation.rewards.reward_function import RewardFn
 from imitation.util import util
-from imitation.util.networks import RunningNorm
 
 
 def _samples_to_reward_fn_input(
@@ -143,83 +141,3 @@ def _get_samples(self):
             "_get_samples() is intentionally not implemented."
             "This method should not be called.",
         )
-
-
-class ReplayBufferEntropyRewardWrapper(ReplayBufferRewardWrapper):
-    """Relabel the rewards from a ReplayBuffer, initially using entropy as reward."""
-
-    def __init__(
-        self,
-        buffer_size: int,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        *,
-        replay_buffer_class: Type[ReplayBuffer],
-        reward_fn: RewardFn,
-        entropy_as_reward_samples: int,
-        k: int = 5,
-        **kwargs,
-    ):
-        """Builds ReplayBufferRewardWrapper.
-
-        Args:
-            buffer_size: Max number of elements in the buffer
-            observation_space: Observation space
-            action_space: Action space
-            replay_buffer_class: Class of the replay buffer.
-            reward_fn: Reward function for reward relabeling.
-            entropy_as_reward_samples: Number of samples to use entropy as the reward,
-                before switching to using the reward_fn for relabeling.
-            k: Use the k'th nearest neighbor's distance when computing state entropy.
-            **kwargs: keyword arguments for ReplayBuffer.
-        """
-        # TODO should we limit by number of batches (as this does)
-        # or number of observations returned?
-        super().__init__(
-            buffer_size,
-            observation_space,
-            action_space,
-            replay_buffer_class=replay_buffer_class,
-            reward_fn=reward_fn,
-            **kwargs,
-        )
-        self.sample_count = 0
-        self.k = k
-        # TODO support n_envs > 1
-        self.entropy_stats = RunningNorm(1)
-        self.entropy_as_reward_samples = entropy_as_reward_samples
-
-    def sample(self, *args, **kwargs):
-        self.sample_count += 1
-        samples = super().sample(*args, **kwargs)
-        # For some reason self.entropy_as_reward_samples seems to get cleared,
-        # and I have no idea why.
-        if self.sample_count > self.entropy_as_reward_samples:
-            return samples
-        # TODO we really ought to reset the reward network once we are done w/
-        # the entropy based pre-training. We also have no reason to train
-        # or even use the reward network before then.
-
-        if self.full:
-            all_obs = self.observations
-        else:
-            all_obs = self.observations[: self.pos]
-        # super().sample() flattens the venv dimension, let's do it too
-        all_obs = all_obs.reshape((-1, *self.obs_shape))
-        entropies = util.compute_state_entropy(
-            samples.observations,
-            all_obs,
-            self.k,
-        )
-
-        # Normalize to have mean of 0 and standard deviation of 1 according to running stats
-        entropies = self.entropy_stats.forward(entropies)
-        assert entropies.shape == samples.rewards.shape
-
-        return ReplayBufferSamples(
-            observations=samples.observations,
-            actions=samples.actions,
-            next_observations=samples.next_observations,
-            dones=samples.dones,
-            rewards=entropies,
-        )
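The removed wrapper scored each sampled observation by its distance to the k-th nearest observation in the buffer (via util.compute_state_entropy) and used the normalized result as the reward during pre-training. A minimal, self-contained sketch of that k-nearest-neighbor idea, assuming flat float observations; this is not the library's compute_state_entropy implementation:

import torch as th

def knn_state_entropy(batch_obs: th.Tensor, all_obs: th.Tensor, k: int = 5) -> th.Tensor:
    """Distance to the k-th nearest buffered observation, used as an entropy proxy.

    batch_obs: (batch, obs_dim); all_obs: (buffer, obs_dim).
    """
    dists = th.cdist(batch_obs, all_obs)  # (batch, buffer) pairwise L2 distances
    knn_dist, _ = th.kthvalue(dists, k, dim=1)  # k-th smallest distance per sample
    return knn_dist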

src/imitation/scripts/common/rl.py

Lines changed: 5 additions & 4 deletions
@@ -86,10 +86,11 @@ def _maybe_add_relabel_buffer(
     """Use ReplayBufferRewardWrapper in rl_kwargs if relabel_reward_fn is not None."""
     rl_kwargs = dict(rl_kwargs)
     if relabel_reward_fn:
-        _buffer_kwargs = dict(reward_fn=relabel_reward_fn)
-        _buffer_kwargs["replay_buffer_class"] = rl_kwargs.get(
-            "replay_buffer_class",
-            buffers.ReplayBuffer,
+        _buffer_kwargs = dict(
+            reward_fn=relabel_reward_fn,
+            replay_buffer_class=rl_kwargs.get(
+                "replay_buffer_class", buffers.ReplayBuffer
+            ),
         )
         rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper
 
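For reference, the kwargs built here mirror how the (now removed) tests wired the wrapper into SB3: the wrapper class is passed as the algorithm's replay_buffer_class, and the inner buffer class plus reward_fn travel through replay_buffer_kwargs. A hedged sketch of that wiring; venv and relabel_reward_fn are assumed to exist, and the kwargs plumbing beyond this hunk is not shown in the diff:

import stable_baselines3 as sb3
from stable_baselines3.common import buffers

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper

rl_algo = sb3.SAC(
    policy="MlpPolicy",
    env=venv,  # assumed pre-built VecEnv
    replay_buffer_class=ReplayBufferRewardWrapper,
    replay_buffer_kwargs=dict(
        replay_buffer_class=buffers.ReplayBuffer,  # inner buffer that actually stores data
        reward_fn=relabel_reward_fn,  # assumed RewardFn used for relabeling
    ),
    buffer_size=10_000,
)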

tests/policies/test_replay_buffer_wrapper.py

Lines changed: 2 additions & 93 deletions
@@ -13,14 +13,10 @@
 from stable_baselines3.common import buffers, off_policy_algorithm, policies
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.policies import BasePolicy
-from stable_baselines3.common.preprocessing import get_obs_shape, get_action_dim
+from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
 from stable_baselines3.common.save_util import load_from_pkl
-from stable_baselines3.common.vec_env import DummyVecEnv
 
-from imitation.policies.replay_buffer_wrapper import (
-    ReplayBufferEntropyRewardWrapper,
-    ReplayBufferRewardWrapper,
-)
+from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
 from imitation.util import util
 
 
@@ -123,54 +119,6 @@ def test_wrapper_class(tmpdir, rng):
     replay_buffer_wrapper._get_samples()
 
 
-# Combine this with the above test via parameterization over the buffer class
-def test_entropy_wrapper_class_no_op(tmpdir, rng):
-    buffer_size = 15
-    total_timesteps = 20
-    entropy_samples = 0
-
-    venv = util.make_vec_env("Pendulum-v1", n_envs=1, rng=rng)
-    rl_algo = sb3.SAC(
-        policy=sb3.sac.policies.SACPolicy,
-        policy_kwargs=dict(),
-        env=venv,
-        seed=42,
-        replay_buffer_class=ReplayBufferEntropyRewardWrapper,
-        replay_buffer_kwargs=dict(
-            replay_buffer_class=buffers.ReplayBuffer,
-            reward_fn=zero_reward_fn,
-            entropy_as_reward_samples=entropy_samples,
-        ),
-        buffer_size=buffer_size,
-    )
-
-    rl_algo.learn(total_timesteps=total_timesteps)
-
-    buffer_path = osp.join(tmpdir, "buffer.pkl")
-    rl_algo.save_replay_buffer(buffer_path)
-    replay_buffer_wrapper = load_from_pkl(buffer_path)
-    replay_buffer = replay_buffer_wrapper.replay_buffer
-
-    # replay_buffer_wrapper.sample(...) should return zero-reward transitions
-    assert buffer_size == replay_buffer_wrapper.size() == replay_buffer.size()
-    assert (replay_buffer_wrapper.sample(total_timesteps).rewards == 0.0).all()
-    assert (replay_buffer.sample(total_timesteps).rewards != 0.0).all()  # seed=42
-
-    # replay_buffer_wrapper.pos, replay_buffer_wrapper.full
-    assert replay_buffer_wrapper.pos == total_timesteps - buffer_size
-    assert replay_buffer_wrapper.full
-
-    # reset()
-    replay_buffer_wrapper.reset()
-    assert 0 == replay_buffer_wrapper.size() == replay_buffer.size()
-    assert replay_buffer_wrapper.pos == 0
-    assert not replay_buffer_wrapper.full
-
-    # to_torch()
-    tensor = replay_buffer_wrapper.to_torch(np.ones(42))
-    assert type(tensor) is th.Tensor
-
-
 class ActionIsObsEnv(gym.Env):
     """Simple environment where the obs is the action."""
 
@@ -191,45 +139,6 @@ def reset(self):
         return np.array([0])
 
 
-def test_entropy_wrapper_class(tmpdir, rng):
-    buffer_size = 20
-    entropy_samples = 500
-    k = 4
-
-    venv = DummyVecEnv([ActionIsObsEnv])
-    rl_algo = sb3.SAC(
-        policy=sb3.sac.policies.SACPolicy,
-        policy_kwargs=dict(),
-        env=venv,
-        seed=42,
-        replay_buffer_class=ReplayBufferEntropyRewardWrapper,
-        replay_buffer_kwargs=dict(
-            replay_buffer_class=buffers.ReplayBuffer,
-            reward_fn=zero_reward_fn,
-            entropy_as_reward_samples=entropy_samples,
-            k=k,
-        ),
-        buffer_size=buffer_size,
-    )
-
-    rl_algo.learn(total_timesteps=buffer_size)
-    initial_entropy = util.compute_state_entropy(
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        k=k,
-    )
-
-    rl_algo.learn(total_timesteps=entropy_samples - buffer_size)
-    # Expect that the entropy of our replay buffer is now higher,
-    # since we trained with that as the reward.
-    trained_entropy = util.compute_state_entropy(
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        k=k,
-    )
-    assert trained_entropy.mean() > initial_entropy.mean()
-
-
 def test_replay_buffer_view_provides_buffered_observations():
     space = spaces.Box(np.array([0]), np.array([5]))
     n_envs = 2

tests/util/test_util.py

Lines changed: 0 additions & 1 deletion
@@ -144,4 +144,3 @@ def test_compute_state_entropy_2d():
         util.compute_state_entropy(obs, all_obs, k=3),
         np.sqrt(20**2 + 2**2),
     )
-
