Skip to content

Commit e59b517

Browse files
committed
add filter percentile logs; make d4rl an offline-to-online example
1 parent 12c9231 commit e59b517

4 files changed

Lines changed: 58 additions & 19 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ Off-policy learning makes it easy to relabel old sequence data with new rewards.
245245
<br>
246246

247247

248-
### **14. Offline RL: D4RL**
248+
### **14. Offline-to-Online RL: D4RL**
249249
**[`14_d4rl.py`](examples/14_d4rl.py)**
250250

251251
<img src="docs/media/d4rl.png" alt="d4rl_diagram" width="100" align="left" />

amago/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,8 @@ def masked_avg(x_, dim=0):
662662
"Filter Max": filter_.max(),
663663
"Filter Min": filter_.min(),
664664
"Filter Mean": (mask * filter_).sum() / mask.sum(),
665+
"Filter 95th Percentile": torch.quantile(filter_, 0.95),
666+
"Filter 75th Percentile": torch.quantile(filter_, 0.75),
665667
"Pct. of Actions Approved by Binary FBC Filter (All Gammas)": utils.masked_avg(
666668
binary_filter, mask
667669
)
@@ -920,7 +922,7 @@ def forward(self, batch: Batch, log_step: bool):
920922
logp_a = logp_a[:, :-1, ...]
921923
actor_loss += self.offline_coeff * -(filter_.detach() * logp_a)
922924
if log_step:
923-
filter_stats = self._filter_stats(actor_mask, logp_a, binary_filter_)
925+
filter_stats = self._filter_stats(actor_mask, logp_a, filter_)
924926
self.update_info.update(filter_stats)
925927

926928
if self.online_coeff > 0:

amago/cli_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
and to break up configuration into several smaller steps.
66
"""
77

8+
import os
89
from argparse import ArgumentParser
910
from typing import Optional
1011

1112
import gin
13+
import wandb
1214

1315
import amago
1416
from amago import TrajEncoder, TstepEncoder, Agent

examples/14_d4rl.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
import amago
1212
from amago.envs import AMAGOEnv
1313
from amago import cli_utils
14-
from amago.loading import RLData, RLDataset
14+
from amago.loading import RLData, RLDataset, DiskTrajDataset, MixtureOfDatasets
1515
from amago.nets.policy_dists import TanhGaussian, GMM, Beta
1616
from amago.nets.actor_critic import ResidualActor, Actor
17+
from amago.agent import binary_filter, exp_filter
1718

1819

1920
def add_cli(parser):
@@ -26,7 +27,7 @@ def add_cli(parser):
2627
parser.add_argument(
2728
"--policy_dist",
2829
type=str,
29-
default="TanhGaussian",
30+
default="Beta",
3031
help="Policy distribution type",
3132
choices=["TanhGaussian", "GMM", "Beta"],
3233
)
@@ -37,6 +38,12 @@ def add_cli(parser):
3738
help="Actor head type",
3839
choices=["ResidualActor", "Actor"],
3940
)
41+
parser.add_argument(
42+
"--online_after_epoch",
43+
type=int,
44+
default=float("inf"),
45+
help="Number of epochs after which to start collecting online data",
46+
)
4047
parser.add_argument(
4148
"--eval_timesteps",
4249
type=int,
@@ -83,7 +90,6 @@ def _sample_trajectory(self, episode_idx: int):
8390
rewards = torch.from_numpy(rewards_np).float().unsqueeze(-1)
8491
time_idxs = torch.arange(traj_len).unsqueeze(-1).long()
8592
dones = torch.from_numpy(terminals_np).bool().unsqueeze(-1)
86-
8793
return RLData(
8894
obs=obs,
8995
actions=actions,
@@ -128,12 +134,14 @@ def reset(self, *args, **kwargs):
128134

129135
def step(self, action):
130136
s, r, d, i = self.env.step(action)
137+
truncated = i.get("TimeLimit.truncated", False)
138+
terminated = d and not truncated
131139
self.episode_return += r
132-
if d:
140+
if terminated or truncated:
133141
i[f"{AMAGO_ENV_LOG_PREFIX} D4RL Normalized Return"] = (
134142
d4rl.get_normalized_score(self.env_name, self.episode_return)
135143
)
136-
return s, r, d, d, i
144+
return s, r, terminated, truncated, i
137145

138146

139147
if __name__ == "__main__":
@@ -148,13 +156,9 @@ def step(self, action):
148156
assert isinstance(
149157
example_env.action_space, gym.spaces.Box
150158
), "Only supports continuous action spaces"
151-
if args.timesteps_per_epoch > 0:
152-
print("WARNING: timesteps_per_epoch is not supported for D4RL, setting to 0")
153-
args.timesteps_per_epoch = 0
154159

155-
# create dataset
156-
dataset = D4RLDataset(d4rl_dset=example_env.dset)
157160
args.eval_timesteps = example_env.time_limit + 1
161+
args.timesteps_per_epoch = example_env.time_limit
158162

159163
# setup environment
160164
make_train_env = lambda: AMAGOEnv(
@@ -166,8 +170,8 @@ def step(self, action):
166170
# agent architecture: drop everything down to standard small sizes
167171
config = {
168172
"amago.nets.actor_critic.NCritics.d_hidden": 128,
169-
"amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 256,
170-
"amago.nets.actor_critic.NCriticsTwoHot.output_bins": 128,
173+
"amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 128,
174+
"amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
171175
"amago.nets.actor_critic.Actor.d_hidden": 128,
172176
"amago.nets.actor_critic.Actor.continuous_dist_type": eval(args.policy_dist),
173177
"amago.nets.actor_critic.ResidualActor.feature_dim": 128,
@@ -184,6 +188,13 @@ def step(self, action):
184188
d_output=128,
185189
n_layers=1,
186190
)
191+
exploration_wrapper_type = cli_utils.switch_exploration(
192+
config,
193+
strategy="egreedy",
194+
eps_start=0.05,
195+
eps_end=0.01,
196+
steps_anneal=15_000,
197+
)
187198
traj_encoder_type = cli_utils.switch_traj_encoder(
188199
config,
189200
arch=args.traj_encoder,
@@ -195,18 +206,37 @@ def step(self, action):
195206
args.agent_type,
196207
online_coeff=0.0,
197208
offline_coeff=1.0,
198-
gamma=0.995,
209+
gamma=0.997,
199210
reward_multiplier=100.0 if example_env.max_return <= 10.0 else 1,
200-
num_actions_for_value_in_critic_loss=2,
201-
num_actions_for_value_in_actor_loss=4,
202-
num_critics=4,
211+
num_actions_for_value_in_critic_loss=3,
212+
num_actions_for_value_in_actor_loss=5,
213+
num_critics=5,
203214
actor_type=eval(args.actor_type),
215+
fbc_filter_func=exp_filter,
204216
)
205217
cli_utils.use_config(config, args.configs)
206218

207219
group_name = f"{args.run_name}_{env_name}"
208220
for trial in range(args.trials):
209221
run_name = group_name + f"_trial_{trial}"
222+
223+
# create dataset
224+
d4rl_dataset = D4RLDataset(d4rl_dset=example_env.dset)
225+
online_dset = DiskTrajDataset(
226+
dset_root=args.buffer_dir,
227+
dset_name=run_name,
228+
dset_min_size=250,
229+
dset_max_size=args.dset_max_size,
230+
)
231+
combined_dset = MixtureOfDatasets(
232+
datasets=[d4rl_dataset, online_dset],
233+
# skew sampling towards the demos 80/20
234+
sampling_weights=[0.8, 0.2],
235+
# gradually increase the weight of the online dset
236+
# over the first 50 epochs *after online collection starts*
237+
smooth_sudden_starts=50,
238+
)
239+
210240
experiment = cli_utils.create_experiment_from_cli(
211241
args,
212242
make_train_env=make_train_env,
@@ -219,10 +249,15 @@ def step(self, action):
219249
group_name=group_name,
220250
val_timesteps_per_epoch=args.eval_timesteps,
221251
learning_rate=1e-4,
222-
dataset=dataset,
252+
dataset=combined_dset,
223253
padded_sampling="right",
254+
start_collecting_at_epoch=args.online_after_epoch,
255+
stagger_traj_file_lengths=False,
256+
traj_save_len=args.eval_timesteps + 1,
224257
sample_actions=False,
258+
exploration_wrapper_type=exploration_wrapper_type,
225259
)
260+
# save a copy of this script at the time of the run
226261
experiment = cli_utils.switch_async_mode(experiment, args.mode)
227262
experiment.start()
228263
if args.ckpt is not None:

0 commit comments

Comments
 (0)