
Commit 189af59

Author: Jan Michelfeit (committed)
#625 add test for pebble agent trainer
1 parent: f3decf1

2 files changed: +24 −2 lines changed


src/imitation/rewards/reward_function.py

Lines changed: 1 addition & 1 deletion
@@ -40,6 +40,6 @@ class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
     @abc.abstractmethod
     def on_replay_buffer_initialized(
         self,
-        replay_buffer: "ReplayBufferRewardWrapper",  # type: ignore[name-defined]
+        replay_buffer: "ReplayBufferRewardWrapper",  # type: ignore[name-defined] # noqa
     ):
         pass
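
For context, `on_replay_buffer_initialized` is the hook that is presumably invoked once a `ReplayBufferRewardWrapper` has its buffer set up; a concrete reward function only has to override it. Below is a minimal sketch of such a subclass, assuming the four-argument `RewardFn` call signature from `reward_function.py`; the class and its `_replay_buffer` attribute are hypothetical illustrations, not part of this commit.

# Hypothetical sketch -- not part of this commit.
import numpy as np

from imitation.rewards.reward_function import ReplayBufferAwareRewardFn


class ZeroBufferAwareReward(ReplayBufferAwareRewardFn):
    def on_replay_buffer_initialized(
        self,
        replay_buffer: "ReplayBufferRewardWrapper",  # type: ignore[name-defined] # noqa
    ):
        # Keep a handle so reward computation could later read stored
        # transitions from the wrapper's buffer.
        self._replay_buffer = replay_buffer

    def __call__(self, state, action, next_state, done) -> np.ndarray:
        # Placeholder reward: one value per state in the batch.
        return np.zeros(len(state))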

tests/algorithms/test_preference_comparisons.py

Lines changed: 23 additions & 1 deletion
@@ -16,8 +16,10 @@
 
 import imitation.testing.reward_nets as testing_reward_nets
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.data.types import TrajectoryWithRew
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.regularization import regularizers, updaters
 from imitation.rewards import reward_nets
 from imitation.util import networks, util
@@ -71,6 +73,23 @@ def agent_trainer(agent, reward_net, venv, rng):
     return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng)
 
 
+@pytest.fixture
+def replay_buffer(rng):
+    return ReplayBufferView(rng.random((10, 8, 4)), lambda: slice(None))
+
+
+@pytest.fixture
+def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
+    reward_fn = PebbleStateEntropyReward(reward_net.predict_processed)
+    reward_fn.set_replay_buffer(replay_buffer, (4,))
+    return preference_comparisons.PebbleAgentTrainer(
+        algorithm=agent,
+        reward_fn=reward_fn,
+        venv=venv,
+        rng=rng,
+    )
+
+
 def _check_trajs_equal(
     trajs1: Sequence[types.TrajectoryWithRew],
     trajs2: Sequence[types.TrajectoryWithRew],
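
The new fixtures wire a `ReplayBufferView` of random observations into `PebbleStateEntropyReward` before the PEBBLE trainer is built. Outside pytest, roughly the same wiring might look like the sketch below; the shape comments and the zero-reward stand-in for `reward_net.predict_processed` are assumptions for illustration, not from the commit.

# Standalone sketch of the fixture wiring above (assumptions noted inline).
import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

rng = np.random.default_rng(0)
# Assumed layout: 10 buffer entries x 8 parallel envs x 4-dim observations.
observations = rng.random((10, 8, 4))
# The callable tells the view which buffer slots to expose; slice(None)
# presumably exposes all of them.
buffer_view = ReplayBufferView(observations, lambda: slice(None))

# Hypothetical stand-in for reward_net.predict_processed: one reward per state.
reward_fn = PebbleStateEntropyReward(lambda state, *a, **kw: np.zeros(len(state)))
reward_fn.set_replay_buffer(buffer_view, (4,))  # (4,) matches the obs shape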
@@ -277,14 +296,17 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
     "schedule",
     ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)],
 )
+@pytest.mark.parametrize("agent_fixture", ["agent_trainer", "pebble_agent_trainer"])
 def test_trainer_no_crash(
-    agent_trainer,
+    request,
+    agent_fixture,
     reward_net,
     random_fragmenter,
     custom_logger,
     schedule,
     rng,
 ):
+    agent_trainer = request.getfixturevalue(agent_fixture)
     main_trainer = preference_comparisons.PreferenceComparisons(
         agent_trainer,
         reward_net,
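
The `agent_fixture` parametrization uses pytest's `request.getfixturevalue` to look the trainer fixture up by name at runtime, so every `schedule` case now runs against both the plain `AgentTrainer` and the PEBBLE trainer. The pattern in isolation looks roughly like this (an illustration with hypothetical fixture names, not from the commit):

import pytest


@pytest.fixture
def fast_trainer():
    return "fast"


@pytest.fixture
def slow_trainer():
    return "slow"


@pytest.mark.parametrize("trainer_fixture", ["fast_trainer", "slow_trainer"])
def test_runs_with_each_trainer(request, trainer_fixture):
    # getfixturevalue resolves the named fixture at runtime, letting one
    # parametrized test reuse several fixtures.
    trainer = request.getfixturevalue(trainer_fixture)
    assert trainer in ("fast", "slow")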
