 class PebbleRewardPhase(Enum):
     """States representing different behaviors for PebbleStateEntropyReward"""
 
-    # Collecting samples so that we have something for entropy calculation
-    LEARNING_START = auto()
-    # Entropy based reward
-    UNSUPERVISED_EXPLORATION = auto()
-    # Learned reward
-    POLICY_AND_REWARD_LEARNING = auto()
+    UNSUPERVISED_EXPLORATION = auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = auto()  # Learned reward
 
 
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """
     Reward function for implementation of the PEBBLE learning algorithm
     (https://arxiv.org/pdf/2106.05091.pdf).
 
-    The rewards returned by this function go through the three phases
-    defined in PebbleRewardPhase. To transition between these phases,
-    unsupervised_exploration_start() and unsupervised_exploration_finish()
-    need to be called.
+    The rewards returned by this function go through the three phases:
+    1. Before enough samples are collected for entropy calculation, the
+       underlying function is returned. This shouldn't matter because
+       OffPolicyAlgorithms have an initialization period of `learning_starts`
+       timesteps.
+    2. During the unsupervised exploration phase, an entropy-based reward is
+       returned.
+    3. After the unsupervised exploration phase is finished, the underlying
+       learned reward is returned.
 
-    The second phase (UNSUPERVISED_EXPLORATION) also requires that a buffer
-    with observations to compare against is supplied with set_replay_buffer()
-    or on_replay_buffer_initialized().
+    The second phase requires that a buffer with observations to compare against is
+    supplied with set_replay_buffer() or on_replay_buffer_initialized().
+    To transition to the last phase, unsupervised_exploration_finish() needs
+    to be called.
 
     Args:
         learned_reward_fn: The learned reward function used after unsupervised
@@ -51,11 +52,10 @@ def __init__(
         learned_reward_fn: RewardFn,
         nearest_neighbor_k: int = 5,
     ):
-        self.trained_reward_fn = learned_reward_fn
+        self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
-        # TODO support n_envs > 1
         self.entropy_stats = RunningNorm(1)
-        self.state = PebbleRewardPhase.LEARNING_START
+        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
         # These two need to be set with set_replay_buffer():
         self.replay_buffer_view = None
@@ -68,10 +68,6 @@ def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
         self.replay_buffer_view = replay_buffer
         self.obs_shape = obs_shape
 
-    def unsupervised_exploration_start(self):
-        assert self.state == PebbleRewardPhase.LEARNING_START
-        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
-
     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
         self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING
@@ -84,26 +80,30 @@ def __call__(
         done: np.ndarray,
     ) -> np.ndarray:
         if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
-            return self._entropy_reward(state)
+            return self._entropy_reward(state, action, next_state, done)
         else:
-            return self.trained_reward_fn(state, action, next_state, done)
+            return self.learned_reward_fn(state, action, next_state, done)
 
-    def _entropy_reward(self, state):
+    def _entropy_reward(self, state, action, next_state, done):
         if self.replay_buffer_view is None:
             raise ValueError(
                 "Replay buffer must be supplied before entropy reward can be used"
             )
-
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
         all_observations = all_observations.reshape((-1, *self.obs_shape))
-        # TODO #625: deal with the conversion back and forth between np and torch
-        entropies = util.compute_state_entropy(
-            th.tensor(state),
-            th.tensor(all_observations),
-            self.nearest_neighbor_k,
-        )
-        normalized_entropies = self.entropy_stats.forward(entropies)
+
+        if all_observations.shape[0] < self.nearest_neighbor_k:
+            # not enough observations to compare to, fall back to the learned function
+            return self.learned_reward_fn(state, action, next_state, done)
+        else:
+            # TODO #625: deal with the conversion back and forth between np and torch
+            entropies = util.compute_state_entropy(
+                th.tensor(state),
+                th.tensor(all_observations),
+                self.nearest_neighbor_k,
+            )
+            normalized_entropies = self.entropy_stats.forward(entropies)
         return normalized_entropies.numpy()
 
     def __getstate__(self):
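
For orientation, the phase lifecycle described in the docstring above can be exercised roughly as follows. This is a minimal sketch and not part of the commit: the dummy reward function, the observation shape, and `replay_buffer_view` are illustrative placeholders; in practice the buffer view comes from the off-policy learner, typically via on_replay_buffer_initialized().

import numpy as np

# Placeholder for a reward model learned from preferences (any RewardFn-style
# callable returning one reward per transition).
def dummy_learned_reward_fn(state, action, next_state, done):
    return np.zeros(len(state), dtype=np.float32)

reward_fn = PebbleStateEntropyReward(
    learned_reward_fn=dummy_learned_reward_fn,
    nearest_neighbor_k=5,
)

# Prerequisite for the entropy phase: supply buffered observations to compare
# against. `replay_buffer_view` is a placeholder for the learner's
# ReplayBufferView; (4,) stands in for the environment's observation shape.
reward_fn.set_replay_buffer(replay_buffer_view, (4,))

obs = np.random.rand(8, 4).astype(np.float32)
acts = np.zeros((8, 1), dtype=np.float32)
dones = np.zeros(8, dtype=bool)

# Unsupervised exploration: returns the normalized state entropy reward, or
# falls back to the learned function while the buffer holds fewer than
# nearest_neighbor_k observations.
rewards = reward_fn(obs, acts, obs, dones)

# Switch to policy-and-reward learning: from here on the learned reward is returned.
reward_fn.unsupervised_exploration_finish()
rewards = reward_fn(obs, acts, obs, dones)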