Commit 2ab0780

Author: Jan Michelfeit

#625 plug in pebble according to parameters

1 parent ad8d76e


4 files changed (+68, -473 lines)

src/imitation/scripts/config/train_preference_comparisons.py

Lines changed: 3 additions & 0 deletions
@@ -60,11 +60,14 @@ def train_defaults():
 
     checkpoint_interval = 0  # Num epochs between saving (<0 disables, =0 final only)
     query_schedule = "hyperbolic"
+    # Whether to use the PEBBLE algorithm (https://arxiv.org/pdf/2106.05091.pdf)
+    pebble_enabled = False
 
 
 @train_preference_comparisons_ex.named_config
 def pebble():
     # fraction of total_timesteps for training before preference gathering
+    pebble_enabled = True
     unsupervised_agent_pretrain_frac = 0.05
     pebble_nearest_neighbor_k = 5
 
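Usage note (not part of the commit): with pebble_enabled exposed as ordinary Sacred config, the PEBBLE variant is switched on by selecting the `pebble` named config. The sketch below is a hedged illustration; only the config keys visible in this diff (pebble_enabled, pebble_nearest_neighbor_k, total_timesteps) come from the source, and it assumes the experiment's default environment settings are sufficient to run.

# Hedged sketch: selecting the `pebble` named config programmatically via
# Sacred's standard Experiment.run API. With Sacred's usual CLI this roughly
# corresponds to: python -m imitation.scripts.train_preference_comparisons with pebble
from imitation.scripts.train_preference_comparisons import (
    train_preference_comparisons_ex,  # importing the script module registers its main command
)

run = train_preference_comparisons_ex.run(
    named_configs=["pebble"],                    # sets pebble_enabled=True, pebble_nearest_neighbor_k=5
    config_updates={"total_timesteps": 10_000},  # small budget, purely for illustration
)
print(run.result)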
src/imitation/scripts/config/train_preference_comparisons_pebble.py

Lines changed: 0 additions & 163 deletions
This file was deleted.

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 65 additions & 18 deletions
@@ -3,24 +3,27 @@
 Can be used as a CLI script, or the `train_preference_comparisons` function
 can be called directly.
 """
-
 import functools
 import pathlib
 from typing import Any, Mapping, Optional, Type, Union
 
+import numpy as np
 import torch as th
 from sacred.observers import FileStorageObserver
-from stable_baselines3.common import type_aliases
+from stable_baselines3.common import type_aliases, base_class, vec_env
 
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.policies import serialize
+from imitation.rewards import reward_nets, reward_function
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
 from imitation.scripts.common import train
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
+from imitation.util import logger as imit_logger
 
 
 def save_model(
@@ -57,6 +60,59 @@ def save_checkpoint(
     )
 
 
+@train_preference_comparisons_ex.capture
+def make_reward_function(
+    reward_net: reward_nets.RewardNet,
+    *,
+    pebble_enabled: bool = False,
+    pebble_nearest_neighbor_k: Optional[int] = None,
+):
+    relabel_reward_fn = functools.partial(
+        reward_net.predict_processed,
+        update_stats=False,
+    )
+    if pebble_enabled:
+        relabel_reward_fn = PebbleStateEntropyReward(
+            relabel_reward_fn, pebble_nearest_neighbor_k
+        )
+    return relabel_reward_fn
+
+
+@train_preference_comparisons_ex.capture
+def make_agent_trajectory_generator(
+    venv: vec_env.VecEnv,
+    agent: base_class.BaseAlgorithm,
+    reward_net: reward_nets.RewardNet,
+    relabel_reward_fn: reward_function.RewardFn,
+    rng: np.random.Generator,
+    custom_logger: Optional[imit_logger.HierarchicalLogger],
+    *,
+    exploration_frac: float,
+    pebble_enabled: bool,
+    trajectory_generator_kwargs: Mapping[str, Any],
+) -> preference_comparisons.AgentTrainer:
+    if pebble_enabled:
+        return preference_comparisons.PebbleAgentTrainer(
+            algorithm=agent,
+            reward_fn=relabel_reward_fn,
+            venv=venv,
+            exploration_frac=exploration_frac,
+            rng=rng,
+            custom_logger=custom_logger,
+            **trajectory_generator_kwargs,
+        )
+    else:
+        return preference_comparisons.AgentTrainer(
+            algorithm=agent,
+            reward_fn=reward_net,
+            venv=venv,
+            exploration_frac=exploration_frac,
+            rng=rng,
+            custom_logger=custom_logger,
+            **trajectory_generator_kwargs,
+        )
+
+
 @train_preference_comparisons_ex.main
 def train_preference_comparisons(
     total_timesteps: int,
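Both new helpers are Sacred captured functions, so their keyword-only parameters (pebble_enabled, pebble_nearest_neighbor_k, exploration_frac, trajectory_generator_kwargs) are filled in from the experiment config at call time rather than passed explicitly. A minimal, self-contained toy illustrating that mechanism (the toy experiment and every name in it are invented for illustration and are not part of this commit):

from sacred import Experiment

toy_ex = Experiment("pebble_toy", save_git_info=False)


@toy_ex.config
def defaults():
    pebble_enabled = False
    pebble_nearest_neighbor_k = None


@toy_ex.named_config
def pebble():
    pebble_enabled = True
    pebble_nearest_neighbor_k = 5


@toy_ex.capture
def make_reward_function_stub(
    reward_net,
    *,
    pebble_enabled: bool = False,
    pebble_nearest_neighbor_k=None,
):
    # Sacred injects the keyword-only arguments from the active run's config.
    return reward_net, pebble_enabled, pebble_nearest_neighbor_k


@toy_ex.main
def main():
    # Called with only the positional argument, just like
    # make_reward_function(reward_net) in the script above.
    return make_reward_function_stub("dummy_reward_net")


print(toy_ex.run(named_configs=["pebble"]).result)
# -> ('dummy_reward_net', True, 5)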
@@ -83,7 +139,6 @@ def train_preference_comparisons(
     checkpoint_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
     unsupervised_agent_pretrain_frac: Optional[float],
-    pebble_nearest_neighbor_k: Optional[int],
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.
 
@@ -146,8 +201,6 @@ def train_preference_comparisons(
         unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
             agent will be trained without preference gathering (and reward model
            training)
-        pebble_nearest_neighbor_k: Parameter for state entropy computation (for PEBBLE
-            training only)
 
     Returns:
         Rollout statistics from trained policy.
@@ -160,10 +213,8 @@ def train_preference_comparisons(
 
     with common.make_venv() as venv:
         reward_net = reward.make_reward_net(venv)
-        relabel_reward_fn = functools.partial(
-            reward_net.predict_processed,
-            update_stats=False,
-        )
+        relabel_reward_fn = make_reward_function(reward_net)
+
         if agent_path is None:
             agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
         else:
@@ -176,21 +227,17 @@ def train_preference_comparisons(
         if trajectory_path is None:
             # Setting the logger here is not necessary (PreferenceComparisons takes care
             # of it automatically) but it avoids creating unnecessary loggers.
-            agent_trainer = preference_comparisons.AgentTrainer(
-                algorithm=agent,
-                reward_fn=reward_net,
+            trajectory_generator = make_agent_trajectory_generator(
                 venv=venv,
-                exploration_frac=exploration_frac,
+                agent=agent,
+                reward_net=reward_net,
+                relabel_reward_fn=relabel_reward_fn,
                 rng=rng,
                 custom_logger=custom_logger,
-                **trajectory_generator_kwargs,
             )
             # Stable Baselines will automatically occupy GPU 0 if it is available.
             # Let's use the same device as the SB3 agent for the reward model.
-            reward_net = reward_net.to(agent_trainer.algorithm.device)
-            trajectory_generator: preference_comparisons.TrajectoryGenerator = (
-                agent_trainer
-            )
+            reward_net = reward_net.to(trajectory_generator.algorithm.device)
         else:
             if exploration_frac > 0:
                 raise ValueError(
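For reference, both branches ultimately hand the RL algorithm a reward callable in imitation's RewardFn style: batched (state, action, next_state, done) arrays in, one reward per transition out. The stand-in below illustrates only that call convention; the stub, its shapes, and dtypes are assumptions for illustration and are not part of the commit.

import numpy as np


def stub_relabel_reward_fn(
    state: np.ndarray,
    action: np.ndarray,
    next_state: np.ndarray,
    done: np.ndarray,
) -> np.ndarray:
    # A real relabel function returns learned-reward predictions here; during
    # PEBBLE pre-training, PebbleStateEntropyReward substitutes a state-entropy bonus.
    return np.zeros(len(state), dtype=np.float32)


batch = 8
rewards = stub_relabel_reward_fn(
    np.zeros((batch, 4), dtype=np.float32),  # observations
    np.zeros((batch, 1), dtype=np.float32),  # actions
    np.zeros((batch, 4), dtype=np.float32),  # next observations
    np.zeros(batch, dtype=bool),             # episode-done flags
)
assert rewards.shape == (batch,)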
