 from gym.spaces import Discrete
 from stable_baselines3.common.preprocessing import get_obs_shape
 
-from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.util import util
 
@@ -24,7 +24,7 @@ def test_state_entropy_reward_returns_entropy(rng):
     all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
 
 
-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)
 
     # Act
@@ -46,7 +46,7 @@ def test_state_entropy_reward_returns_normalized_values():
     # mock entropy computation so that we can test only stats collection in this test
     m.side_effect = lambda obs, all_obs, k: obs
 
-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
     reward_fn.set_replay_buffer(
         ReplayBufferView(all_observations, lambda: slice(None)),
@@ -80,7 +80,7 @@ def test_state_entropy_reward_can_pickle():
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))
 
     obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
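For reference, a minimal sketch of how the renamed PebbleStateEntropyReward is wired up, mirroring the calls in the tests above. The SPACE, K, BUFFER_SIZE, VENVS, and PLACEHOLDER values below are assumed stand-ins for the constants defined elsewhere in the test module (not shown in this diff), and the API usage is taken directly from the diff, so treat this as an illustrative sketch rather than canonical usage.

import numpy as np
from gym.spaces import Discrete
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

# Assumed stand-ins for the constants defined earlier in the test module.
SPACE = Discrete(4)
K = 4
BUFFER_SIZE = 20
VENVS = 2
PLACEHOLDER = np.empty(0)  # assumed placeholder, as used for the tests' unused reward args

obs_shape = get_obs_shape(SPACE)
all_observations = np.random.rand(BUFFER_SIZE, VENVS, *obs_shape)

# Build the reward and point it at a view over the replay buffer's observations,
# as the updated tests do after the rename.
reward_fn = PebbleStateEntropyReward(K, SPACE)
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)

# Query rewards for a batch of observations (one per vectorized env).
obs = np.random.rand(VENVS, *obs_shape)
rewards = reward_fn(obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)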