@@ -1506,7 +1506,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
-        initial_agent_pretrain_frac: float = 0.05,
+        unsupervised_agent_pretrain_frac: float = 0.05,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1556,7 +1556,7 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
-            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+            unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
                 agent will be trained without preference gathering (and reward model
                 training)
             custom_logger: Where to log to; if None (default), creates a new logger.
@@ -1657,7 +1657,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
-        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
+        self.unsupervised_agent_pretrain_frac = unsupervised_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1691,7 +1691,7 @@ def train(
         print(f"Query schedule: {preference_query_schedule}")

         (
-            agent_pretrain_timesteps,
+            unsupervised_pretrain_timesteps,
             timesteps_per_iteration,
             extra_timesteps,
         ) = self._compute_timesteps(total_timesteps)
@@ -1703,9 +1703,9 @@ def train(
         ###################################################
         with self.logger.accumulate_means("agent"):
             self.logger.log(
-                f"Pre-training agent for {agent_pretrain_timesteps} timesteps"
+                f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
             )
-            self.trajectory_generator.unsupervised_pretrain(agent_pretrain_timesteps)
+            self.trajectory_generator.unsupervised_pretrain(unsupervised_pretrain_timesteps)

         for i, num_pairs in enumerate(preference_query_schedule):
             ##########################
@@ -1782,11 +1782,11 @@ def _preference_gather_schedule(self, total_comparisons):
         return schedule

     def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
-        agent_pretrain_timesteps = int(
-            total_timesteps * self.initial_agent_pretrain_frac
+        unsupervised_pretrain_timesteps = int(
+            total_timesteps * self.unsupervised_agent_pretrain_frac
         )
         timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps - agent_pretrain_timesteps,
+            total_timesteps - unsupervised_pretrain_timesteps,
             self.num_iterations,
         )
-        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
+        return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
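For orientation only (not part of the patch above), a minimal sketch of the split that _compute_timesteps performs after this rename. The budget values and num_iterations = 5 are assumptions chosen for illustration, not defaults of the library.

# Illustrative sketch only; the names mirror the diff above but the
# concrete numbers are assumed for this example.
total_timesteps = 100_000
unsupervised_agent_pretrain_frac = 0.05  # default from the new signature
num_iterations = 5  # assumed value

# Reward-free agent pre-training budget: 100_000 * 0.05 = 5_000 timesteps.
unsupervised_pretrain_timesteps = int(
    total_timesteps * unsupervised_agent_pretrain_frac
)

# The remaining 95_000 timesteps are split across preference-learning
# iterations: divmod(95_000, 5) == (19_000, 0).
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - unsupervised_pretrain_timesteps,
    num_iterations,
)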