File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
src/imitation/algorithms/pebble Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -35,13 +35,14 @@ def __call__(
35
35
36
36
all_observations = self .replay_buffer_view .observations
37
37
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
38
- all_observations = all_observations .reshape ((- 1 , * self .obs_shape ))
38
+ all_observations = all_observations .reshape ((- 1 , * state .shape [1 :])) # TODO #625: fix self.obs_shape
39
+ # TODO #625: deal with the conversion back and forth between np and torch
39
40
entropies = util .compute_state_entropy (
40
- state ,
41
- all_observations ,
41
+ th . tensor ( state ) ,
42
+ th . tensor ( all_observations ) ,
42
43
self .nearest_neighbor_k ,
43
44
)
44
- normalized_entropies = self .entropy_stats .forward (th . as_tensor ( entropies ) )
45
+ normalized_entropies = self .entropy_stats .forward (entropies )
45
46
return normalized_entropies .numpy ()
46
47
47
48
def __getstate__ (self ):
You can’t perform that action at this time.
0 commit comments