
Commit 716c710

Author: Jan Michelfeit
Commit message: #625 rename unsupervised_agent_pretrain_frac parameter
Parent: 15c682a

File tree: 1 file changed (+10, -10 lines)


src/imitation/algorithms/preference_comparisons.py

Lines changed: 10 additions & 10 deletions
@@ -1506,7 +1506,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
-        initial_agent_pretrain_frac: float = 0.05,
+        unsupervised_agent_pretrain_frac: float = 0.05,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1556,7 +1556,7 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
-            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+            unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
                 agent will be trained without preference gathering (and reward model
                 training)
             custom_logger: Where to log to; if None (default), creates a new logger.
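
In other words, the renamed parameter carves an unsupervised pretraining phase out of the overall timestep budget. A minimal sketch of the arithmetic, using hypothetical values (total_timesteps=10_000 is illustrative, not taken from the commit):

    total_timesteps = 10_000
    unsupervised_agent_pretrain_frac = 0.05  # default from the diff above

    # Timesteps spent training the agent before any preferences are
    # gathered or the reward model is updated.
    pretrain_timesteps = int(total_timesteps * unsupervised_agent_pretrain_frac)
    assert pretrain_timesteps == 500  # the remaining 9_500 go to the main loop
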
@@ -1657,7 +1657,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
-        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
+        self.unsupervised_agent_pretrain_frac = unsupervised_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1691,7 +1691,7 @@ def train(
         print(f"Query schedule: {preference_query_schedule}")

         (
-            agent_pretrain_timesteps,
+            unsupervised_pretrain_timesteps,
             timesteps_per_iteration,
             extra_timesteps,
         ) = self._compute_timesteps(total_timesteps)
@@ -1703,9 +1703,9 @@ def train(
         ###################################################
         with self.logger.accumulate_means("agent"):
             self.logger.log(
-                f"Pre-training agent for {agent_pretrain_timesteps} timesteps"
+                f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
             )
-            self.trajectory_generator.unsupervised_pretrain(agent_pretrain_timesteps)
+            self.trajectory_generator.unsupervised_pretrain(unsupervised_pretrain_timesteps)

         for i, num_pairs in enumerate(preference_query_schedule):
             ##########################
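
The unsupervised_pretrain call above delegates to the trajectory generator; the commit only renames the variable passed in. A hypothetical sketch of a generator satisfying this hook (the class name and the use of stable-baselines3's learn() method are assumptions, not part of this file):

    class ExplorationTrajectoryGenerator:
        """Hypothetical trajectory generator exposing the pretraining hook."""

        def __init__(self, agent):
            self.agent = agent  # e.g. a stable-baselines3 algorithm

        def unsupervised_pretrain(self, steps: int, **kwargs) -> None:
            # Train the agent without the learned reward signal, e.g. on an
            # exploration bonus; here simply the environment's own reward.
            self.agent.learn(total_timesteps=steps)
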
@@ -1782,11 +1782,11 @@ def _preference_gather_schedule(self, total_comparisons):
         return schedule

     def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
-        agent_pretrain_timesteps = int(
-            total_timesteps * self.initial_agent_pretrain_frac
+        unsupervised_pretrain_timesteps = int(
+            total_timesteps * self.unsupervised_agent_pretrain_frac
         )
         timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps - agent_pretrain_timesteps,
+            total_timesteps - unsupervised_pretrain_timesteps,
             self.num_iterations,
         )
-        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
+        return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
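
As a sanity check on the split, the same arithmetic as a standalone function (the concrete inputs in the assert are illustrative assumptions, not values from the repository):

    from typing import Tuple

    def compute_timesteps(
        total_timesteps: int,
        unsupervised_agent_pretrain_frac: float,
        num_iterations: int,
    ) -> Tuple[int, int, int]:
        # Mirrors PreferenceComparisons._compute_timesteps above,
        # re-implemented standalone for illustration.
        unsupervised_pretrain_timesteps = int(
            total_timesteps * unsupervised_agent_pretrain_frac
        )
        timesteps_per_iteration, extra_timesteps = divmod(
            total_timesteps - unsupervised_pretrain_timesteps,
            num_iterations,
        )
        return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps

    # 5% of 10_000 steps go to pretraining; the remaining 9_500 split into
    # 3 iterations of 3_166 steps each, with 2 extra timesteps left over.
    assert compute_timesteps(10_000, 0.05, 3) == (500, 3_166, 2)
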
