
Commit ad8d76e

Author: Jan Michelfeit (committed)
#625 merge pebble to train_preference_comparisons.py and configure only through sacred
1 parent 152efa6 commit ad8d76e

File tree

3 files changed: +35 −1 lines changed

src/imitation/algorithms/preference_comparisons.py

Lines changed: 4 additions & 0 deletions

@@ -344,6 +344,10 @@ def __init__(
         reward_fn: PebbleStateEntropyReward,
         **kwargs,
     ) -> None:
+        if not isinstance(reward_fn, PebbleStateEntropyReward):
+            raise ValueError(
+                f"{self.__class__.__name__} expects {PebbleStateEntropyReward.__name__} reward function"
+            )
         super().__init__(reward_fn=reward_fn, **kwargs)

     def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
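
For orientation, the new guard only affects construction: wiring the PEBBLE trainer with anything other than a PebbleStateEntropyReward now fails immediately instead of later in training. A self-contained sketch of the behaviour it enforces follows; the stand-in classes are illustrative, since the real trainer class and its other constructor arguments are not shown in this diff.

# Illustrative stand-ins only; the real PebbleStateEntropyReward lives in
# imitation.algorithms.preference_comparisons and the real trainer takes more arguments.
class PebbleStateEntropyReward:
    pass


class PebbleTrainerSketch:
    def __init__(self, *, reward_fn, **kwargs):
        # Same isinstance guard as added in the diff above.
        if not isinstance(reward_fn, PebbleStateEntropyReward):
            raise ValueError(
                f"{self.__class__.__name__} expects "
                f"{PebbleStateEntropyReward.__name__} reward function",
            )
        self.reward_fn = reward_fn


try:
    PebbleTrainerSketch(reward_fn=object())  # wrong reward type
except ValueError as err:
    print(err)  # PebbleTrainerSketch expects PebbleStateEntropyReward reward function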

src/imitation/scripts/config/train_preference_comparisons.py

Lines changed: 23 additions & 1 deletion

@@ -1,8 +1,10 @@
 """Configuration for imitation.scripts.train_preference_comparisons."""

 import sacred
+import stable_baselines3 as sb3

 from imitation.algorithms import preference_comparisons
+from imitation.policies import base
 from imitation.scripts.common import common, reward, rl, train

 train_preference_comparisons_ex = sacred.Experiment(
@@ -15,7 +17,6 @@
     ],
 )

-
 MUJOCO_SHARED_LOCALS = dict(rl=dict(rl_kwargs=dict(ent_coef=0.1)))
 ANT_SHARED_LOCALS = dict(
     total_timesteps=int(3e7),
@@ -61,6 +62,26 @@ def train_defaults():
     query_schedule = "hyperbolic"


+@train_preference_comparisons_ex.named_config
+def pebble():
+    # fraction of total_timesteps for training before preference gathering
+    unsupervised_agent_pretrain_frac = 0.05
+    pebble_nearest_neighbor_k = 5
+
+    rl = {
+        "rl_cls": sb3.SAC,
+        "batch_size": 256,  # batch size for RL algorithm
+        "rl_kwargs": {"batch_size": None},  # make sure to set batch size to None
+    }
+    train = {
+        "policy_cls": base.SAC1024Policy,  # noqa: F841
+    }
+    common = {"env_name": "MountainCarContinuous-v0"}
+    allow_variable_horizon = True
+
+    locals()  # quieten flake8
+
+
 @train_preference_comparisons_ex.named_config
 def cartpole():
     common = dict(env_name="CartPole-v1")
@@ -121,6 +142,7 @@ def fast():
     total_timesteps = 50
     total_comparisons = 5
     initial_comparison_frac = 0.2
+    unsupervised_agent_pretrain_frac = 0.2
     num_iterations = 1
     fragment_length = 2
     reward_trainer_kwargs = {
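
Per the commit message, PEBBLE is now configured only through Sacred, so the pebble named config added above is how PEBBLE gets switched on. A minimal sketch of activating it from Python follows; it assumes the experiment's main command is registered by importing the run script, and combines pebble with the existing fast config only to keep the run short.

# Sketch: run the experiment with the new "pebble" named config via Sacred.
# Importing the run script is assumed to register the experiment's main command.
from imitation.scripts.train_preference_comparisons import (
    train_preference_comparisons_ex,
)

run = train_preference_comparisons_ex.run(named_configs=["pebble", "fast"])
print(run.result)  # rollout statistics from the trained policy

From the command line the equivalent would typically be "python -m imitation.scripts.train_preference_comparisons with pebble", though the exact entry point is not part of this diff.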

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 8 additions & 0 deletions

@@ -82,6 +82,8 @@ def train_preference_comparisons(
     allow_variable_horizon: bool,
     checkpoint_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
+    unsupervised_agent_pretrain_frac: Optional[float],
+    pebble_nearest_neighbor_k: Optional[int],
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.

@@ -141,6 +143,11 @@
             be allocated to each iteration. "hyperbolic" and "inverse_quadratic"
             apportion fewer queries to later iterations when the policy is assumed
             to be better and more stable.
+        unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
+            agent will be trained without preference gathering (and reward model
+            training)
+        pebble_nearest_neighbor_k: Parameter for state entropy computation (for PEBBLE
+            training only)

     Returns:
         Rollout statistics from trained policy.
@@ -244,6 +251,7 @@ def train_preference_comparisons(
         custom_logger=custom_logger,
         allow_variable_horizon=allow_variable_horizon,
         query_schedule=query_schedule,
+        unsupervised_agent_pretrain_frac=unsupervised_agent_pretrain_frac,
     )

     def save_callback(iteration_num):
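
The two new arguments thread the PEBBLE settings through the Sacred script: unsupervised_agent_pretrain_frac is forwarded to PreferenceComparisons in the last hunk, while pebble_nearest_neighbor_k is accepted here but its use is not shown in this diff. As a back-of-the-envelope reading of the fraction, using the values added to the fast config above:

# Reading of unsupervised_agent_pretrain_frac with the "fast" config values
# added in this commit (total_timesteps=50, unsupervised_agent_pretrain_frac=0.2).
# How PreferenceComparisons rounds or schedules this split is not shown here.
total_timesteps = 50
unsupervised_agent_pretrain_frac = 0.2
pretrain_timesteps = int(total_timesteps * unsupervised_agent_pretrain_frac)
print(pretrain_timesteps)  # 10 timesteps of agent training before any preferences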
