@@ -75,6 +75,19 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
            be the environment rewards, not ones from a reward model).
        """  # noqa: DAR202

+    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
+        """Pre-train an agent if the trajectory generator uses one that
+        needs pre-training.
+
+        By default, this method does nothing and doesn't need
+        to be overridden in subclasses that don't require pre-training.
+
+        Args:
+            steps: number of environment steps to train for.
+            **kwargs: additional keyword arguments to pass on to
+                the training procedure.
+        """
+
    def train(self, steps: int, **kwargs: Any) -> None:
        """Train an agent if the trajectory generator uses one.

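Note: the new hook above is a no-op on the base class, so existing trajectory generators keep working unchanged; only generators whose agents benefit from a pre-training phase need to override it. Below is a minimal sketch of such an override. It assumes the hunk above targets imitation's preference_comparisons module; the ExplorationAgentTrainer class, its exploration_algo attribute, and rollout_fn are hypothetical stand-ins, not part of this change.

from typing import Any, Sequence

# Assumption: the base class shown in the hunk above is
# imitation.algorithms.preference_comparisons.TrajectoryGenerator.
from imitation.algorithms.preference_comparisons import TrajectoryGenerator
from imitation.data.types import TrajectoryWithRew


class ExplorationAgentTrainer(TrajectoryGenerator):  # hypothetical subclass
    """Generator whose agent runs an unsupervised exploration phase first."""

    def __init__(self, exploration_algo: Any, rollout_fn: Any) -> None:
        super().__init__()
        self.exploration_algo = exploration_algo  # hypothetical: anything with .learn()
        self.rollout_fn = rollout_fn  # hypothetical: maps steps -> trajectories

    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        # Override the base-class no-op: train on an intrinsic objective
        # before any preference-based reward model exists.
        self.exploration_algo.learn(total_timesteps=steps, **kwargs)

    def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
        # Roll out the (pre-)trained agent to produce trajectories.
        return self.rollout_fn(steps)
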
@@ -1493,7 +1506,7 @@ def __init__(
        transition_oversampling: float = 1,
        initial_comparison_frac: float = 0.1,
        initial_epoch_multiplier: float = 200.0,
-        initial_agent_pretrain_frac: float = 0.01,
+        initial_agent_pretrain_frac: float = 0.05,
        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
        allow_variable_horizon: bool = False,
        rng: Optional[np.random.Generator] = None,
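Note on the default change above: assuming agent_pretrain_timesteps is derived as initial_agent_pretrain_frac times the total timesteps passed to train() (the exact formula is not shown in this diff), the new default roughly quintuples the pre-training budget.

# Hypothetical illustration of the new default; the formula below is an
# assumption, not taken from this diff.
total_timesteps = 1_000_000
old_budget = int(0.01 * total_timesteps)  # 10_000 pre-training steps
new_budget = int(0.05 * total_timesteps)  # 50_000 pre-training steps
print(old_budget, new_budget)
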
@@ -1685,6 +1698,15 @@ def train(
        reward_loss = None
        reward_accuracy = None

+        ###################################################
+        # Pre-training agent before gathering preferences #
+        ###################################################
+        with self.logger.accumulate_means("agent"):
+            self.logger.log(
+                f"Pre-training agent for {agent_pretrain_timesteps} timesteps"
+            )
+            self.trajectory_generator.unsupervised_pretrain(agent_pretrain_timesteps)
+
        for i, num_pairs in enumerate(preference_query_schedule):
            ##########################
            # Gather new preferences #
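For orientation, a self-contained sketch of the control flow the last hunk introduces: the trajectory generator is pre-trained once, before the first batch of preference queries is gathered. All names below are hypothetical stand-ins, not the library's API.

from typing import Any, Sequence


class StubGenerator:
    """Hypothetical stand-in for a trajectory generator with the new hook."""

    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        print(f"pre-training for {steps} steps (no reward model yet)")

    def train(self, steps: int, **kwargs: Any) -> None:
        print(f"training for {steps} steps against the learned reward")


def run(generator: StubGenerator, schedule: Sequence[int], pretrain_steps: int) -> None:
    # New phase added by this change: one-off pre-training before any queries.
    generator.unsupervised_pretrain(pretrain_steps)
    for num_pairs in schedule:
        print(f"gathering {num_pairs} preference pairs, updating the reward model")
        generator.train(steps=1_000)


run(StubGenerator(), schedule=[50, 25, 25], pretrain_steps=5_000)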