|
33 | 33 | from tqdm.auto import tqdm
|
34 | 34 |
|
35 | 35 | from imitation.algorithms import base
|
| 36 | +from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward |
36 | 37 | from imitation.data import rollout, types, wrappers
|
37 | 38 | from imitation.data.types import (
|
38 | 39 | AnyPath,
|
@@ -329,6 +330,27 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None:
|
329 | 330 | self.algorithm.set_logger(self.logger)
|
330 | 331 |
|
331 | 332 |
|
| 333 | +class PebbleAgentTrainer(AgentTrainer): |
| 334 | +    """Specialization of AgentTrainer for PEBBLE training. |
| 335 | + |
| 336 | +    Includes unsupervised pretraining with an entropy-based reward function. |
| 337 | +    """ |
| 338 | + |
| 339 | +    reward_fn: PebbleStateEntropyReward |
| 340 | + |
| 341 | +    def __init__( |
| 342 | +        self, |
| 343 | +        *, |
| 344 | +        reward_fn: PebbleStateEntropyReward, |
| 345 | +        **kwargs, |
| 346 | +    ) -> None: |
| 347 | +        super().__init__(reward_fn=reward_fn, **kwargs) |
| 348 | + |
| 349 | +    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None: |
| 350 | +        self.train(steps, **kwargs)  # agent training on the entropy-based pretraining reward |
| 351 | +        self.reward_fn.unsupervised_exploration_finish()  # mark the end of the exploration phase |
| 352 | + |
| 353 | + |
332 | 354 | def _get_trajectories(
|
333 | 355 | trajectories: Sequence[TrajectoryWithRew],
|
334 | 356 | steps: int,
|
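For orientation, here is a minimal usage sketch of the new trainer. Only `PebbleAgentTrainer(reward_fn=..., **kwargs)` and `unsupervised_pretrain(steps)` come from this diff; the `PebbleStateEntropyReward` constructor arguments and the surrounding objects (the off-policy learner, vectorized environment, and RNG) are placeholders, since their setup sits outside the visible hunks.

```python
# Hypothetical wiring of PebbleAgentTrainer; only the class and method names
# come from this diff, everything marked as a placeholder is assumed.
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.algorithms.preference_comparisons import PebbleAgentTrainer

entropy_reward = PebbleStateEntropyReward(...)  # placeholder: constructor not shown in the diff

trainer = PebbleAgentTrainer(
    algorithm=sac_learner,     # placeholder: an off-policy SB3 algorithm such as SAC
    reward_fn=entropy_reward,  # must be a PebbleStateEntropyReward
    venv=venv,                 # placeholder: vectorized training environment
    rng=rng,                   # placeholder: numpy random generator
)

# PEBBLE's first phase: explore on the entropy-based reward, after which
# unsupervised_exploration_finish() ends the pretraining phase.
trainer.unsupervised_pretrain(10_000)
```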
@@ -1705,7 +1727,9 @@ def train(
|
1705 | 1727 | self.logger.log(
|
1706 | 1728 | f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
|
1707 | 1729 | )
|
1708 |
| - self.trajectory_generator.unsupervised_pretrain(unsupervised_pretrain_timesteps) |
| 1730 | + self.trajectory_generator.unsupervised_pretrain( |
| 1731 | + unsupervised_pretrain_timesteps |
| 1732 | + ) |
1709 | 1733 |
|
1710 | 1734 | for i, num_pairs in enumerate(preference_query_schedule):
|
1711 | 1735 | ##########################
|
|
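For context, the second hunk sits inside the main `train()` loop; the sketch below paraphrases the assumed surrounding control flow. Only the logged message and the `unsupervised_pretrain` call are visible in the diff, so the guard and the loop body are assumptions for illustration.

```python
# Paraphrased control flow around the changed call; the `if` guard and the loop
# body are assumptions, only the logging and pretrain lines appear in the diff.
if unsupervised_pretrain_timesteps:
    self.logger.log(
        f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
    )
    # For PebbleAgentTrainer this runs the entropy-reward exploration phase and
    # then calls unsupervised_exploration_finish() on the reward function.
    self.trajectory_generator.unsupervised_pretrain(
        unsupervised_pretrain_timesteps
    )

for i, num_pairs in enumerate(preference_query_schedule):
    # ... gather trajectories, query preferences, train reward model and agent
    ...
```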