1
+ import pickle
1
2
from unittest .mock import patch
2
3
3
4
import numpy as np
@@ -33,7 +34,9 @@ def test_state_entropy_reward_returns_entropy(rng):
33
34
expected = util .compute_state_entropy (
34
35
observations , all_observations .reshape (- 1 , * obs_shape ), K
35
36
)
36
- expected_normalized = reward_fn .entropy_stats .normalize (th .as_tensor (expected )).numpy ()
37
+ expected_normalized = reward_fn .entropy_stats .normalize (
38
+ th .as_tensor (expected )
39
+ ).numpy ()
37
40
np .testing .assert_allclose (reward , expected_normalized )
38
41
39
42
@@ -44,7 +47,7 @@ def test_state_entropy_reward_returns_normalized_values():
44
47
45
48
reward_fn = StateEntropyReward (K , SPACE )
46
49
all_observations = np .empty ((BUFFER_SIZE , VENVS , * get_obs_shape (SPACE )))
47
- reward_fn .set_buffer_view (
50
+ reward_fn .set_replay_buffer (
48
51
ReplayBufferView (all_observations , lambda : slice (None ))
49
52
)
50
53
@@ -68,3 +71,24 @@ def test_state_entropy_reward_returns_normalized_values():
68
71
rtol = 0.05 ,
69
72
atol = 0.05 ,
70
73
)
74
+
75
+
76
def test_state_entropy_reward_can_pickle():
    """Check that a pickle round-trip preserves the reward function's output.

    A reward function is called once, pickled and restored; the restored
    copy is given the same replay buffer and must then return the same
    reward as the original for a fresh batch of observations.
    """
    # Arrange
    # Seeded, fully initialised data keeps the test reproducible. A buffer
    # from ``np.empty`` holds arbitrary memory (possibly NaN/inf), and
    # ``assert_allclose`` treats NaNs as equal by default, so the comparison
    # below could pass without checking anything.
    rng = np.random.default_rng(0)
    obs_shape = get_obs_shape(SPACE)
    all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
    replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

    obs1 = rng.random((VENVS, *obs_shape))
    reward_fn = StateEntropyReward(K, SPACE)
    reward_fn.set_replay_buffer(replay_buffer)
    # Call once before pickling so the serialised object is not a freshly
    # constructed one (presumably this updates its entropy statistics --
    # TODO confirm against StateEntropyReward).
    reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Act
    pickled = pickle.dumps(reward_fn)
    # Safe here: the payload was produced a line above, not read from outside.
    reward_fn_deserialized = pickle.loads(pickled)
    # The buffer is attached again after loading; presumably it is not part
    # of the pickled state -- verify against StateEntropyReward.
    reward_fn_deserialized.set_replay_buffer(replay_buffer)

    # Assert
    obs2 = rng.random((VENVS, *obs_shape))
    expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    actual_result = reward_fn_deserialized(
        obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
    )
    np.testing.assert_allclose(actual_result, expected_result)
0 commit comments