 Can be used as a CLI script, or the `train_preference_comparisons` function
 can be called directly.
 """
-
 import functools
 import pathlib
 from typing import Any, Mapping, Optional, Type, Union
 
+import numpy as np
 import torch as th
 from sacred.observers import FileStorageObserver
-from stable_baselines3.common import type_aliases
+from stable_baselines3.common import type_aliases, base_class, vec_env
 
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.policies import serialize
+from imitation.rewards import reward_nets, reward_function
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
 from imitation.scripts.common import train
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
+from imitation.util import logger as imit_logger
 
 
 def save_model(
@@ -57,6 +60,59 @@ def save_checkpoint(
     )
 
 
+@train_preference_comparisons_ex.capture
+def make_reward_function(
+    reward_net: reward_nets.RewardNet,
+    *,
+    pebble_enabled: bool = False,
+    pebble_nearest_neighbor_k: Optional[int] = None,
+):
+    relabel_reward_fn = functools.partial(
+        reward_net.predict_processed,
+        update_stats=False,
+    )
+    if pebble_enabled:
+        relabel_reward_fn = PebbleStateEntropyReward(
+            relabel_reward_fn, pebble_nearest_neighbor_k
+        )
+    return relabel_reward_fn
+
+
+@train_preference_comparisons_ex.capture
+def make_agent_trajectory_generator(
+    venv: vec_env.VecEnv,
+    agent: base_class.BaseAlgorithm,
+    reward_net: reward_nets.RewardNet,
+    relabel_reward_fn: reward_function.RewardFn,
+    rng: np.random.Generator,
+    custom_logger: Optional[imit_logger.HierarchicalLogger],
+    *,
+    exploration_frac: float,
+    pebble_enabled: bool,
+    trajectory_generator_kwargs: Mapping[str, Any],
+) -> preference_comparisons.AgentTrainer:
+    if pebble_enabled:
+        return preference_comparisons.PebbleAgentTrainer(
+            algorithm=agent,
+            reward_fn=relabel_reward_fn,
+            venv=venv,
+            exploration_frac=exploration_frac,
+            rng=rng,
+            custom_logger=custom_logger,
+            **trajectory_generator_kwargs,
+        )
+    else:
+        return preference_comparisons.AgentTrainer(
+            algorithm=agent,
+            reward_fn=reward_net,
+            venv=venv,
+            exploration_frac=exploration_frac,
+            rng=rng,
+            custom_logger=custom_logger,
+            **trajectory_generator_kwargs,
+        )
+
+
 @train_preference_comparisons_ex.main
 def train_preference_comparisons(
     total_timesteps: int,
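The new `make_reward_function` helper above wraps the learned reward in `PebbleStateEntropyReward` when PEBBLE is enabled, with `pebble_nearest_neighbor_k` controlling how many nearest neighbours feed the state-entropy estimate. As a rough illustration of that idea only (a standalone sketch, not the library's implementation; `knn_entropy_reward` and its arguments are hypothetical), a particle-based entropy bonus scores a state by its distance to the k-th nearest previously visited state:

```python
import numpy as np


def knn_entropy_reward(
    states: np.ndarray,          # (batch, obs_dim): states to score
    visited_states: np.ndarray,  # (buffer, obs_dim): states seen so far
    k: int = 5,                  # plays the role of pebble_nearest_neighbor_k
) -> np.ndarray:
    """Illustrative k-NN state-entropy bonus: rarely visited states score higher."""
    # Pairwise Euclidean distances between the batch and the visited buffer.
    dists = np.linalg.norm(
        states[:, None, :] - visited_states[None, :, :], axis=-1
    )
    # Distance from each state to its k-th nearest neighbour in the buffer.
    kth = np.sort(dists, axis=1)[:, min(k, dists.shape[1]) - 1]
    # The log of the k-NN distance is (up to constants) a particle-based
    # estimate of each state's contribution to the entropy of the visitation
    # distribution, so it rewards exploring under-visited regions.
    return np.log(kth + 1e-8)
```

Such a bonus is presumably only relevant during the unsupervised pretraining phase controlled by `unsupervised_agent_pretrain_frac` below; afterwards the wrapped learned reward takes over.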
@@ -83,7 +139,6 @@ def train_preference_comparisons(
     checkpoint_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
     unsupervised_agent_pretrain_frac: Optional[float],
-    pebble_nearest_neighbor_k: Optional[int],
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.
 
@@ -146,8 +201,6 @@ def train_preference_comparisons(
         unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
             agent will be trained without preference gathering (and reward model
             training)
-        pebble_nearest_neighbor_k: Parameter for state entropy computation (for PEBBLE
-            training only)
 
     Returns:
         Rollout statistics from trained policy.
@@ -160,10 +213,8 @@ def train_preference_comparisons(
 
     with common.make_venv() as venv:
         reward_net = reward.make_reward_net(venv)
-        relabel_reward_fn = functools.partial(
-            reward_net.predict_processed,
-            update_stats=False,
-        )
+        relabel_reward_fn = make_reward_function(reward_net)
+
         if agent_path is None:
             agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
         else:
@@ -176,21 +227,17 @@ def train_preference_comparisons(
         if trajectory_path is None:
             # Setting the logger here is not necessary (PreferenceComparisons takes care
             # of it automatically) but it avoids creating unnecessary loggers.
-            agent_trainer = preference_comparisons.AgentTrainer(
-                algorithm=agent,
-                reward_fn=reward_net,
+            trajectory_generator = make_agent_trajectory_generator(
                 venv=venv,
-                exploration_frac=exploration_frac,
+                agent=agent,
+                reward_net=reward_net,
+                relabel_reward_fn=relabel_reward_fn,
                 rng=rng,
                 custom_logger=custom_logger,
-                **trajectory_generator_kwargs,
             )
             # Stable Baselines will automatically occupy GPU 0 if it is available.
             # Let's use the same device as the SB3 agent for the reward model.
-            reward_net = reward_net.to(agent_trainer.algorithm.device)
-            trajectory_generator: preference_comparisons.TrajectoryGenerator = (
-                agent_trainer
-            )
+            reward_net = reward_net.to(trajectory_generator.algorithm.device)
         else:
             if exploration_frac > 0:
                 raise ValueError(
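The module docstring notes the file doubles as a CLI script and that `train_preference_comparisons` can be invoked directly. A minimal driver sketch follows, assuming the Sacred experiment's defaults already supply an environment and that the new `pebble_enabled` / `pebble_nearest_neighbor_k` keys resolve as top-level config updates (they are parameters of the captured helpers added above, which Sacred fills from the experiment config); the values are placeholders.

```python
"""Hypothetical driver for the refactored script (values are placeholders)."""
from imitation.scripts.config.train_preference_comparisons import (
    train_preference_comparisons_ex,
)

run = train_preference_comparisons_ex.run(
    config_updates=dict(
        total_timesteps=20_000,
        # Assumed to route to make_reward_function / make_agent_trajectory_generator
        # via Sacred's captured-function config resolution.
        pebble_enabled=True,
        pebble_nearest_neighbor_k=5,
    ),
)
# The main function returns rollout statistics, which Sacred exposes on the run.
print(run.result)
```

The same keys could equally be passed on the command line via Sacred's `with key=value` syntax.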