
Commit d348534
Author: Jan Michelfeit
Message: #625 rename PebbleStateEntropyReward
Parent: 3d7cfca

2 files changed: 6 additions, 5 deletions

src/imitation/algorithms/pebble/entropy_reward.py (2 additions, 1 deletion)

@@ -14,7 +14,8 @@
 from imitation.util.networks import RunningNorm
 
 
-class StateEntropyReward(ReplayBufferAwareRewardFn):
+class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
+    # TODO #625: get rid of the observation_space parameter
     def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
         self.nearest_neighbor_k = nearest_neighbor_k
         # TODO support n_envs > 1

tests/algorithms/pebble/test_entropy_reward.py (4 additions, 4 deletions)

@@ -6,7 +6,7 @@
 from gym.spaces import Discrete
 from stable_baselines3.common.preprocessing import get_obs_shape
 
-from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.util import util
 
@@ -24,7 +24,7 @@ def test_state_entropy_reward_returns_entropy(rng):
     all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
 
 
-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)
 
     # Act
@@ -46,7 +46,7 @@ def test_state_entropy_reward_returns_normalized_values():
         # mock entropy computation so that we can test only stats collection in this test
         m.side_effect = lambda obs, all_obs, k: obs
 
-        reward_fn = StateEntropyReward(K, SPACE)
+        reward_fn = PebbleStateEntropyReward(K, SPACE)
         all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
         reward_fn.set_replay_buffer(
             ReplayBufferView(all_observations, lambda: slice(None)),
@@ -80,7 +80,7 @@ def test_state_entropy_reward_can_pickle():
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))
 
     obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
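For context, here is a minimal usage sketch of the renamed class, assembled from the calls exercised in the tests above. The constant values (K, SPACE, BUFFER_SIZE, VENVS) mirror the test fixtures but are illustrative assumptions, as is passing the observation batch in place of the action/next-state/done arguments, which the tests fill with placeholders:

import numpy as np
from gym.spaces import Discrete
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

# Illustrative constants; the tests define their own values for these.
K = 4                 # nearest_neighbor_k for the entropy estimate
SPACE = Discrete(4)   # observation_space (slated for removal per TODO #625)
BUFFER_SIZE, VENVS = 20, 1
obs_shape = get_obs_shape(SPACE)

# Construct the renamed reward function and attach a view of the replay buffer.
reward_fn = PebbleStateEntropyReward(K, SPACE)
all_observations = np.empty((BUFFER_SIZE, VENVS, *obs_shape))
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)

# Query rewards for a batch of observations; as in the pickling test above,
# the remaining arguments are passed only as placeholders.
obs = np.random.rand(VENVS, *obs_shape)
rewards = reward_fn(obs, obs, obs, obs)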
