Commit 1f50696

Author: Jan Michelfeit (committed)

#625 make entropy reward serializable with pickle

1 parent: 2dec99f

3 files changed: 41 additions & 4 deletions

src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 14 additions & 2 deletions

@@ -14,13 +14,14 @@ def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
         self.nearest_neighbor_k = nearest_neighbor_k
         # TODO support n_envs > 1
         self.entropy_stats = RunningNorm(1)
+        self.observation_space = observation_space
         self.obs_shape = get_obs_shape(observation_space)
         self.replay_buffer_view = ReplayBufferView(
             np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
         )
 
-    def set_buffer_view(self, replay_buffer_view: ReplayBufferView):
-        self.replay_buffer_view = replay_buffer_view
+    def set_replay_buffer(self, replay_buffer: ReplayBufferView):
+        self.replay_buffer_view = replay_buffer
 
     def __call__(
         self,
@@ -42,3 +43,14 @@ def __call__(
         )
         normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
         return normalized_entropies.numpy()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["replay_buffer_view"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.replay_buffer_view = ReplayBufferView(
+            np.empty(0, self.observation_space.dtype), lambda: slice(0)
+        )
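
The additions follow the usual __getstate__/__setstate__ pattern for objects that hold a non-picklable member: the ReplayBufferView keeps a lambda-based reference to the live buffer, so it is dropped from the pickled state and a harmless placeholder is rebuilt on load, after which the caller re-attaches the real buffer via set_replay_buffer(). A minimal, self-contained sketch of the same pattern (the BufferHolder class below is illustrative only, not part of this commit):

    import pickle

    import numpy as np


    class BufferHolder:
        """Illustrative class holding a non-picklable view, mirroring the pattern above."""

        def __init__(self, k: int):
            self.k = k
            # A lambda makes this attribute non-picklable, like ReplayBufferView's getter.
            self.view = (np.empty(0), lambda: slice(0))

        def __getstate__(self):
            state = self.__dict__.copy()
            del state["view"]  # exclude the non-picklable member
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            # Recreate a placeholder; the caller re-attaches the real view afterwards.
            self.view = (np.empty(0), lambda: slice(0))


    holder = pickle.loads(pickle.dumps(BufferHolder(k=5)))
    assert holder.k == 5  # the picklable state survives; the view must be set again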

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 from imitation.rewards.reward_function import RewardFn
 from imitation.util import util
 from imitation.util.networks import RunningNorm
+from typing import Callable
 
 
 def _samples_to_reward_fn_input(

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 26 additions & 2 deletions

@@ -1,3 +1,4 @@
+import pickle
 from unittest.mock import patch
 
 import numpy as np
@@ -33,7 +34,9 @@ def test_state_entropy_reward_returns_entropy(rng):
     expected = util.compute_state_entropy(
         observations, all_observations.reshape(-1, *obs_shape), K
     )
-    expected_normalized = reward_fn.entropy_stats.normalize(th.as_tensor(expected)).numpy()
+    expected_normalized = reward_fn.entropy_stats.normalize(
+        th.as_tensor(expected)
+    ).numpy()
     np.testing.assert_allclose(reward, expected_normalized)
 
 
@@ -44,7 +47,7 @@ def test_state_entropy_reward_returns_normalized_values():
 
     reward_fn = StateEntropyReward(K, SPACE)
     all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
-    reward_fn.set_buffer_view(
+    reward_fn.set_replay_buffer(
         ReplayBufferView(all_observations, lambda: slice(None))
     )
@@ -68,3 +71,24 @@ def test_state_entropy_reward_returns_normalized_values():
         rtol=0.05,
         atol=0.05,
     )
+
+
+def test_state_entropy_reward_can_pickle():
+    all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
+    replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))
+
+    obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
+    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn.set_replay_buffer(replay_buffer)
+    reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+    # Act
+    pickled = pickle.dumps(reward_fn)
+    reward_fn_deserialized = pickle.loads(pickled)
+    reward_fn_deserialized.set_replay_buffer(replay_buffer)
+
+    # Assert
+    obs2 = np.random.rand(VENVS, *get_obs_shape(SPACE))
+    expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+    actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+    np.testing.assert_allclose(actual_result, expected_result)
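
The same round trip applies when checkpointing a run outside the test: the reward function can be pickled to disk, but the replay buffer view has to be re-attached after loading. A rough usage sketch, assuming gym's spaces module for the observation space, a buffer shaped like the tests above, and import paths taken from the files this commit touches (the file name reward_fn.pkl and the sizes are illustrative):

    import pickle

    import numpy as np
    from gym import spaces

    from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
    from imitation.policies.replay_buffer_wrapper import ReplayBufferView

    space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
    buffer = np.empty((1000, 1, 4))  # (buffer_size, n_envs, *obs_shape)
    view = ReplayBufferView(buffer, lambda: slice(None))

    reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=space)
    reward_fn.set_replay_buffer(view)

    # Checkpoint: the view is dropped by __getstate__, so dumping succeeds.
    with open("reward_fn.pkl", "wb") as f:
        pickle.dump(reward_fn, f)

    # Restore: re-attach the live buffer view before computing rewards again.
    with open("reward_fn.pkl", "rb") as f:
        restored = pickle.load(f)
    restored.set_replay_buffer(view)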
