@@ -68,6 +68,7 @@ def train_preference_comparisons(
     fragment_length: int,
     transition_oversampling: float,
     initial_comparison_frac: float,
+    initial_epoch_multiplier: float,
     exploration_frac: float,
     trajectory_path: Optional[str],
     trajectory_generator_kwargs: Mapping[str, Any],
@@ -106,6 +107,9 @@ def train_preference_comparisons(
             sampled before the rest of training begins (using the randomly initialized
             agent). This can be used to pretrain the reward model before the agent
             is trained on the learned reward.
+        initial_epoch_multiplier: before agent training begins, train the reward
+            model for this many more epochs than usual (on fragments sampled from a
+            random agent).
         exploration_frac: fraction of trajectory samples that will be created using
             partially random actions, rather than the current policy. Might be helpful
             if the learned policy explores too little and gets stuck with a wrong
@@ -258,6 +262,7 @@ def train_preference_comparisons(
         fragment_length=fragment_length,
         transition_oversampling=transition_oversampling,
         initial_comparison_frac=initial_comparison_frac,
+        initial_epoch_multiplier=initial_epoch_multiplier,
         custom_logger=custom_logger,
         allow_variable_horizon=allow_variable_horizon,
         query_schedule=query_schedule,
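For context, the sketch below illustrates the behaviour the new docstring describes: the first round of reward-model training, which only sees fragments from the randomly initialized agent, runs for `initial_epoch_multiplier` times the usual number of epochs. The helper name `reward_training_epochs`, the `base_epochs` parameter, and the numeric values are illustrative only and are not taken from this patch or the library.

```python
# Illustrative sketch only; names and values are hypothetical, not from this patch.
def reward_training_epochs(
    iteration: int,
    base_epochs: int,
    initial_epoch_multiplier: float,
) -> int:
    """Epochs of reward-model training to run in a given outer iteration."""
    if iteration == 0:
        # Pretraining phase: comparisons come from a randomly initialized agent,
        # so train the reward model for many more epochs than usual.
        return int(base_epochs * initial_epoch_multiplier)
    return base_epochs


# With base_epochs=3 and initial_epoch_multiplier=100.0 (values chosen purely
# for illustration), the initial round trains for 300 epochs, later rounds for 3.
assert reward_training_epochs(0, base_epochs=3, initial_epoch_multiplier=100.0) == 300
assert reward_training_epochs(1, base_epochs=3, initial_epoch_multiplier=100.0) == 3
```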