from imitation.util.networks import RunningNorm


-class PebbleRewardPhase(enum.Enum):
-    """States representing different behaviors for PebbleStateEntropyReward."""
-
-    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
-    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
-
-
class InsufficientObservations(RuntimeError):
    pass


-class EntropyRewardNet(RewardNet):
+class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
    def __init__(
        self,
        nearest_neighbor_k: int,
-        replay_buffer_view: ReplayBufferView,
        observation_space: gym.Space,
        action_space: gym.Space,
        normalize_images: bool = True,
+        replay_buffer_view: Optional[ReplayBufferView] = None,
    ):
        """Initialize the RewardNet.

        Args:
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
            observation_space: the observation space of the environment
            action_space: the action space of the environment
            normalize_images: whether to automatically normalize
                image observations to [0, 1] (from 0 to 255). Defaults to True.
+            replay_buffer_view: Replay buffer view with observations to compare
+                against when computing entropy. If None is given, the buffer needs to
+                be set with on_replay_buffer_initialized() before EntropyRewardNet can
+                be used
        """
        super().__init__(observation_space, action_space, normalize_images)
        self.nearest_neighbor_k = nearest_neighbor_k
        self._replay_buffer_view = replay_buffer_view

-    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
-        """This method needs to be called after unpickling.
+    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
+        """Sets replay buffer.

-        See also __getstate__() / __setstate__()
+        This method needs to be called, e.g., after unpickling.
+        See also __getstate__() / __setstate__().
        """
        assert self.observation_space == replay_buffer.observation_space
        assert self.action_space == replay_buffer.action_space
@@ -111,6 +111,13 @@ def __setstate__(self, state):
        self._replay_buffer_view = None


+class PebbleRewardPhase(enum.Enum):
+    """States representing different behaviors for PebbleStateEntropyReward."""
+
+    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
+
+
class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
    """Reward function for implementation of the PEBBLE learning algorithm.

@@ -126,14 +133,15 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
    reward is returned.

    The second phase requires that a buffer with observations to compare against is
-    supplied with set_replay_buffer() or on_replay_buffer_initialized().
-    To transition to the last phase, unsupervised_exploration_finish() needs
-    to be called.
+    supplied with on_replay_buffer_initialized(). To transition to the last phase,
+    unsupervised_exploration_finish() needs to be called.
    """

    def __init__(
        self,
        learned_reward_fn: RewardFn,
+        observation_space: gym.Space,
+        action_space: gym.Space,
        nearest_neighbor_k: int = 5,
    ):
        """Builds this class.
@@ -146,28 +154,20 @@ def __init__(
        """
        self.learned_reward_fn = learned_reward_fn
        self.nearest_neighbor_k = nearest_neighbor_k
-
        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

-        # These two need to be set with set_replay_buffer():
-        self._entropy_reward_net: Optional[EntropyRewardNet] = None
-        self._normalized_entropy_reward_net: Optional[RewardNet] = None
+        self._entropy_reward_net = EntropyRewardNet(
+            nearest_neighbor_k=self.nearest_neighbor_k,
+            observation_space=observation_space,
+            action_space=action_space,
+            normalize_images=False,
+        )
+        self._normalized_entropy_reward_net = NormalizedRewardNet(
+            self._entropy_reward_net, RunningNorm
+        )

    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        if self._normalized_entropy_reward_net is None:
-            self._entropy_reward_net = EntropyRewardNet(
-                nearest_neighbor_k=self.nearest_neighbor_k,
-                replay_buffer_view=replay_buffer.buffer_view,
-                observation_space=replay_buffer.observation_space,
-                action_space=replay_buffer.action_space,
-                normalize_images=False,
-            )
-            self._normalized_entropy_reward_net = NormalizedRewardNet(
-                self._entropy_reward_net, RunningNorm
-            )
-        else:
-            assert self._entropy_reward_net is not None
-            self._entropy_reward_net.set_replay_buffer(replay_buffer)
+        self._entropy_reward_net.on_replay_buffer_initialized(replay_buffer)

    def unsupervised_exploration_finish(self):
        assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
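For context, a minimal sketch of how the reworked interface is meant to be used. This is not code from the PR: the import path and the dummy `learned_reward_fn` are assumptions made for illustration, while the constructor arguments, `on_replay_buffer_initialized()`, and `unsupervised_exploration_finish()` are taken from the diff above.

```python
import gym
import numpy as np

# Assumed import path; adjust to wherever this branch defines the class.
from imitation.rewards.pebble import PebbleStateEntropyReward

observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))


def learned_reward_fn(obs, acts, next_obs, dones):
    """Stand-in reward model; imitation's RewardFn maps batched transitions to rewards."""
    return np.zeros(len(obs), dtype=np.float32)


reward_fn = PebbleStateEntropyReward(
    learned_reward_fn=learned_reward_fn,
    observation_space=observation_space,
    action_space=action_space,
    nearest_neighbor_k=5,
)

# The entropy-based exploration phase needs observations to compare against, so the
# replay buffer wrapper hands itself to the reward function once it exists, e.g.:
# reward_fn.on_replay_buffer_initialized(replay_buffer_wrapper)

# After unsupervised pre-training, switch to the learned reward:
# reward_fn.unsupervised_exploration_finish()
```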