 import imitation.testing.reward_nets as testing_reward_nets
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.data.types import TrajectoryWithRew
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.regularization import regularizers, updaters
 from imitation.rewards import reward_nets
 from imitation.util import networks, util
@@ -71,6 +73,23 @@ def agent_trainer(agent, reward_net, venv, rng):
     return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng)


+@pytest.fixture
+def replay_buffer(rng):
+    return ReplayBufferView(rng.random((10, 8, 4)), lambda: slice(None))
+
+
+@pytest.fixture
+def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
+    reward_fn = PebbleStateEntropyReward(reward_net.predict_processed)
+    reward_fn.set_replay_buffer(replay_buffer, (4,))
+    return preference_comparisons.PebbleAgentTrainer(
+        algorithm=agent,
+        reward_fn=reward_fn,
+        venv=venv,
+        rng=rng,
+    )
+
+
 def _check_trajs_equal(
     trajs1: Sequence[types.TrajectoryWithRew],
     trajs2: Sequence[types.TrajectoryWithRew],
@@ -277,14 +296,17 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
     "schedule",
     ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)],
 )
+@pytest.mark.parametrize("agent_fixture", ["agent_trainer", "pebble_agent_trainer"])
 def test_trainer_no_crash(
-    agent_trainer,
+    request,
+    agent_fixture,
     reward_net,
     random_fragmenter,
     custom_logger,
     schedule,
     rng,
 ):
+    agent_trainer = request.getfixturevalue(agent_fixture)
     main_trainer = preference_comparisons.PreferenceComparisons(
         agent_trainer,
         reward_net,
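The last hunk parametrizes the test over fixture names rather than over fixtures themselves, then resolves the chosen fixture at runtime with pytest's request.getfixturevalue. A minimal standalone sketch of that idiom, with invented fixture names and return values for illustration:

import pytest


@pytest.fixture
def default_trainer():
    # Hypothetical stand-in for agent_trainer.
    return "default"


@pytest.fixture
def pebble_trainer():
    # Hypothetical stand-in for pebble_agent_trainer.
    return "pebble"


@pytest.mark.parametrize("trainer_fixture", ["default_trainer", "pebble_trainer"])
def test_works_with_both_trainers(request, trainer_fixture):
    # Look up whichever fixture the current parametrization names.
    trainer = request.getfixturevalue(trainer_fixture)
    assert trainer in ("default", "pebble")

The same lookup is what lets test_trainer_no_crash run once with the plain AgentTrainer fixture and once with the PEBBLE variant without duplicating the test body.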