Skip to content

Commit f957baf

Browse files
author
Jan Michelfeit
committed
#625 add optional pretraining to PreferenceComparisons
1 parent 9090b0c commit f957baf

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/imitation/algorithms/preference_comparisons.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
7575
be the environment rewards, not ones from a reward model).
7676
""" # noqa: DAR202
7777

def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
    """Pre-train an agent when the trajectory generator requires it.

    By default, this method does nothing and doesn't need
    to be overridden in subclasses that don't require pre-training.

    Args:
        steps: number of environment steps to train for.
        **kwargs: additional keyword arguments to pass on to
            the training procedure.
    """
7891
def train(self, steps: int, **kwargs: Any) -> None:
7992
"""Train an agent if the trajectory generator uses one.
8093
@@ -1493,7 +1506,7 @@ def __init__(
14931506
transition_oversampling: float = 1,
14941507
initial_comparison_frac: float = 0.1,
14951508
initial_epoch_multiplier: float = 200.0,
1496-
initial_agent_pretrain_frac: float = 0.01,
1509+
initial_agent_pretrain_frac: float = 0.05,
14971510
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
14981511
allow_variable_horizon: bool = False,
14991512
rng: Optional[np.random.Generator] = None,
@@ -1685,6 +1698,15 @@ def train(
16851698
reward_loss = None
16861699
reward_accuracy = None
16871700

1701+
###################################################
1702+
# Pre-training agent before gathering preferences #
1703+
###################################################
1704+
with self.logger.accumulate_means("agent"):
1705+
self.logger.log(
1706+
f"Pre-training agent for {agent_pretrain_timesteps} timesteps"
1707+
)
1708+
self.trajectory_generator.unsupervised_pretrain(agent_pretrain_timesteps)
1709+
16881710
for i, num_pairs in enumerate(preference_query_schedule):
16891711
##########################
16901712
# Gather new preferences #

0 commit comments

Comments
 (0)