Skip to content

Commit 78553c9

Browse files
committed
Add initial epoch multiplier as a parameter to the PC script.
1 parent 55aa6eb commit 78553c9

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/imitation/scripts/config/train_preference_comparisons.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,8 @@ def train_defaults():
4242
transition_oversampling = 1
4343
# fraction of total_comparisons that will be sampled right at the beginning
4444
initial_comparison_frac = 0.1
45+
# factor by which to oversample the number of epochs in the first iteration
46+
initial_epoch_multiplier = 200.0
4547
# fraction of sampled trajectories that will include some random actions
4648
exploration_frac = 0.0
4749
preference_model_kwargs = {}

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ def train_preference_comparisons(
6868
fragment_length: int,
6969
transition_oversampling: float,
7070
initial_comparison_frac: float,
71+
initial_epoch_multiplier: float,
7172
exploration_frac: float,
7273
trajectory_path: Optional[str],
7374
trajectory_generator_kwargs: Mapping[str, Any],
@@ -106,6 +107,9 @@ def train_preference_comparisons(
106107
sampled before the rest of training begins (using the randomly initialized
107108
agent). This can be used to pretrain the reward model before the agent
108109
is trained on the learned reward.
110+
initial_epoch_multiplier: before agent training begins, train the reward
111+
model for this many more epochs than usual (on fragments sampled from a
112+
random agent).
109113
exploration_frac: fraction of trajectory samples that will be created using
110114
partially random actions, rather than the current policy. Might be helpful
111115
if the learned policy explores too little and gets stuck with a wrong
@@ -258,6 +262,7 @@ def train_preference_comparisons(
258262
fragment_length=fragment_length,
259263
transition_oversampling=transition_oversampling,
260264
initial_comparison_frac=initial_comparison_frac,
265+
initial_epoch_multiplier=initial_epoch_multiplier,
261266
custom_logger=custom_logger,
262267
allow_variable_horizon=allow_variable_horizon,
263268
query_schedule=query_schedule,

0 commit comments

Comments (0)