
Commit 4fd0758

#641 increase coverage

Jan Michelfeit committed
1 parent 74ba96b · commit 4fd0758

File tree

6 files changed: +54 −27 lines

src/imitation/policies/replay_buffer_wrapper.py
tests/algorithms/pebble/test_entropy_reward.py
tests/algorithms/test_preference_comparisons.py
tests/policies/test_replay_buffer_wrapper.py
tests/scripts/test_scripts.py
tests/scripts/test_train_preference_comparisons.py

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 9 additions & 2 deletions
@@ -143,5 +143,12 @@ class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
     def on_replay_buffer_initialized(
         self,
         replay_buffer: ReplayBufferRewardWrapper,
-    ):
-        pass
+    ) -> None:
+        """Hook method to be called when ReplayBuffer is initialized.
+
+        Needed to propagate the ReplayBuffer to a reward function because the buffer
+        is created indirectly in ReplayBufferRewardWrapper.
+
+        Args:
+            replay_buffer: the created ReplayBuffer
+        """  # noqa: DAR202

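For context on the hook this docstring documents: a reward function that needs the buffer would subclass ReplayBufferAwareRewardFn and keep a reference to the wrapper once the hook fires. The sketch below is illustrative only; the class name, the stored attribute, and the placeholder reward are assumptions, not the library's actual PEBBLE implementation.

from typing import Optional

import numpy as np

from imitation.policies.replay_buffer_wrapper import (
    ReplayBufferAwareRewardFn,
    ReplayBufferRewardWrapper,
)


class BufferTrackingRewardFn(ReplayBufferAwareRewardFn):
    """Illustrative reward function that remembers the wrapped replay buffer."""

    def __init__(self) -> None:
        # Filled in by the hook; the buffer does not exist yet at construction time.
        self._replay_buffer: Optional[ReplayBufferRewardWrapper] = None

    def on_replay_buffer_initialized(
        self,
        replay_buffer: ReplayBufferRewardWrapper,
    ) -> None:
        # The wrapper creates the underlying buffer indirectly, so this hook is
        # the first point at which the reward function can see it.
        self._replay_buffer = replay_buffer

    def __call__(self, state, action, next_state, done) -> np.ndarray:
        # Placeholder reward; a real implementation (e.g. PEBBLE's state entropy
        # reward) would use the stored buffer here.
        return np.zeros(len(state), dtype=np.float32)

The point of the hook is ordering: the wrapper builds the underlying ReplayBuffer itself, so the reward function cannot receive it at construction time and must be handed the wrapper afterwards.
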
tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 0 additions & 4 deletions
@@ -155,10 +155,6 @@ def test_entropy_reward_net_can_pickle(rng):
     np.testing.assert_allclose(actual_result, expected_result)
 
 
-def reward_fn_stub(state, action, next_state, done):
-    return state
-
-
 def replay_buffer_mock(all_observations: np.ndarray, obs_space: Space = SPACE) -> Mock:
     buffer_view = ReplayBufferView(all_observations, lambda: slice(None))
     mock = Mock()

tests/algorithms/test_preference_comparisons.py

Lines changed: 28 additions & 0 deletions
@@ -18,11 +18,16 @@
 
 import imitation.testing.reward_nets as testing_reward_nets
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.preference_comparisons import (
+    PebbleAgentTrainer,
+    TrajectoryGenerator,
+)
 from imitation.data import types
 from imitation.data.types import TrajectoryWithRew
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.regularization import regularizers, updaters
 from imitation.rewards import reward_nets
+from imitation.rewards.reward_function import RewardFn
 from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
 from imitation.util import networks, util
 
@@ -1120,3 +1125,26 @@ def test_that_trainer_improves(
     )
 
     assert np.mean(trained_agent_rewards) > np.mean(novice_agent_rewards)
+
+
+def test_trajectory_generator_raises_on_pretrain_if_not_implemented():
+    class TrajectoryGeneratorTestImpl(TrajectoryGenerator):
+        def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
+            return []
+
+    generator = TrajectoryGeneratorTestImpl()
+    assert generator.has_pretraining is False
+    with pytest.raises(ValueError, match="should not consume any timesteps"):
+        generator.unsupervised_pretrain(1)
+
+
+def test_pebble_agent_trainer_expects_pebble_reward(agent, venv, rng):
+    reward_fn: RewardFn = lambda state, action, next, done: state
+
+    with pytest.raises(ValueError, match="PebbleStateEntropyReward"):
+        PebbleAgentTrainer(
+            algorithm=agent,
+            reward_fn=reward_fn,  # type: ignore[call-arg]
+            venv=venv,
+            rng=rng,
+        )

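The first new test fixes the base-class contract: a TrajectoryGenerator that only implements sample reports has_pretraining as False, and its inherited unsupervised_pretrain raises because a generator without pretraining should not consume any timesteps. For contrast, a generator that does opt into pretraining might look roughly like the sketch below; the class is hypothetical, and the idea that overriding unsupervised_pretrain is what enables pretraining is inferred from the test, not taken from the library source.

from typing import Sequence

from imitation.algorithms.preference_comparisons import TrajectoryGenerator
from imitation.data.types import TrajectoryWithRew


class PretrainingGeneratorSketch(TrajectoryGenerator):
    """Hypothetical generator that spends some timesteps on unsupervised pretraining."""

    def __init__(self) -> None:
        super().__init__()
        self.pretrain_steps_used = 0

    def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
        # A real generator would roll out its policy for `steps` timesteps here.
        return []

    def unsupervised_pretrain(self, steps: int) -> None:
        # Overriding the raising default is presumably what makes
        # has_pretraining report True for this generator.
        self.pretrain_steps_used += steps

The second new test documents a related guard: PebbleAgentTrainer raises a ValueError unless its reward_fn is a PebbleStateEntropyReward, which is why the plain lambda in that test is rejected.
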
tests/policies/test_replay_buffer_wrapper.py

Lines changed: 0 additions & 21 deletions
@@ -4,7 +4,6 @@
 from typing import Type
 from unittest.mock import Mock
 
-import gym
 import numpy as np
 import pytest
 import stable_baselines3 as sb3
@@ -122,26 +121,6 @@ def test_wrapper_class(tmpdir, rng):
     replay_buffer_wrapper._get_samples()
 
 
-class ActionIsObsEnv(gym.Env):
-    """Simple environment where the obs is the action."""
-
-    def __init__(self):
-        """Initialize environment."""
-        super().__init__()
-        self.action_space = spaces.Box(np.array([0]), np.array([1]))
-        self.observation_space = spaces.Box(np.array([0]), np.array([1]))
-
-    def step(self, action):
-        obs = action
-        reward = 0
-        done = False
-        info = {}
-        return obs, reward, done, info
-
-    def reset(self):
-        return np.array([0])
-
-
 def test_replay_buffer_view_provides_buffered_observations():
     space = spaces.Box(np.array([0]), np.array([5]))
     n_envs = 2

tests/scripts/test_scripts.py

Lines changed: 14 additions & 0 deletions
@@ -254,6 +254,20 @@ def test_train_preference_comparisons_reward_named_config(tmpdir, named_configs)
     assert isinstance(run.result, dict)
 
 
+def test_train_preference_comparisons_pebble_config(tmpdir):
+    config_updates = dict(common=dict(log_root=tmpdir))
+    run = train_preference_comparisons.train_preference_comparisons_ex.run(
+        # make sure rl.sac named_config is called after rl.fast to overwrite
+        # rl_kwargs.batch_size to None
+        named_configs=ALGO_FAST_CONFIGS["preference_comparison"]
+        + ["pebble", "mountain_car_continuous"],
+        config_updates=config_updates,
+    )
+    assert run.config["rl"]["rl_cls"] is stable_baselines3.SAC
+    assert run.status == "COMPLETED"
+    assert isinstance(run.result, dict)
+
+
 def test_train_dagger_main(tmpdir):
     with pytest.warns(None) as record:
         run = train_imitation.train_imitation_ex.run(

tests/scripts/test_train_preference_comparisons.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ def test_creates_normalized_entropy_pebble_reward():
         atol=0.05,
     )
 
+    # Just to make coverage happy:
+    reward_fn_stub(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
 
 def reward_fn_stub(state, action, next_state, done):
     return state
