Skip to content

Commit 152efa6

Browse files
author
Jan Michelfeit
committed
#625 specialized PebbleAgentTrainer to distinguish from old preference comparison trainer
1 parent 716c710 commit 152efa6

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

src/imitation/algorithms/preference_comparisons.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tqdm.auto import tqdm
3434

3535
from imitation.algorithms import base
36+
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
3637
from imitation.data import rollout, types, wrappers
3738
from imitation.data.types import (
3839
AnyPath,
@@ -329,6 +330,27 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None:
329330
self.algorithm.set_logger(self.logger)
330331

331332

333+
class PebbleAgentTrainer(AgentTrainer):
    """AgentTrainer variant used for PEBBLE training.

    On top of the regular AgentTrainer behaviour it offers an unsupervised
    pretraining phase whose reward comes from an entropy based reward
    function (a ``PebbleStateEntropyReward``).
    """

    # Narrowed type of the attribute, so that PEBBLE-specific methods of the
    # reward function can be called on it.
    reward_fn: PebbleStateEntropyReward

    def __init__(
        self,
        *,
        reward_fn: PebbleStateEntropyReward,
        **kwargs,
    ) -> None:
        """Builds the trainer.

        Args:
            reward_fn: entropy based reward function; passed on to
                ``AgentTrainer.__init__`` under the same keyword.
            **kwargs: forwarded unchanged to ``AgentTrainer.__init__``.
        """
        super().__init__(reward_fn=reward_fn, **kwargs)

    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        """Runs the unsupervised pretraining phase.

        Trains via ``self.train`` and afterwards notifies the reward function
        by calling its ``unsupervised_exploration_finish`` method.

        Args:
            steps: passed as the first argument to ``self.train``.
            **kwargs: forwarded unchanged to ``self.train``.
        """
        self.train(steps, **kwargs)
        # Looked up only after training, in the same order as the direct
        # attribute access it replaces.
        entropy_reward = self.reward_fn
        entropy_reward.unsupervised_exploration_finish()
352+
353+
332354
def _get_trajectories(
333355
trajectories: Sequence[TrajectoryWithRew],
334356
steps: int,
@@ -1705,7 +1727,9 @@ def train(
17051727
self.logger.log(
17061728
f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
17071729
)
1708-
self.trajectory_generator.unsupervised_pretrain(unsupervised_pretrain_timesteps)
1730+
self.trajectory_generator.unsupervised_pretrain(
1731+
unsupervised_pretrain_timesteps
1732+
)
17091733

17101734
for i, num_pairs in enumerate(preference_query_schedule):
17111735
##########################

0 commit comments

Comments
 (0)