Skip to content

Commit 2893f5b

Browse files
theap06 and claude committed
[Feature] OfflineToOnlineTrainer + sota script for offline->online RL
Follow-up to the OfflineToOnlineReplayBuffer PR: a SAC trainer that drives the offline-pretrain -> online-finetune transition, plus a standalone sota-implementations script.

- OfflineToOnlineTrainer (subclasses SACTrainer): routes collected experience to the online buffer (pre_epoch), samples a mixed offline/online batch (process_optim_batch), and anneals the offline fraction to zero over anneal_frames (post_steps). Backed by two reusable hooks: OfflineToOnlineReplayBufferHook (projects online transitions onto the offline dataset schema so the mixed-batch concat stays valid) and OfflineToOnlineAnnealHook.
- sota-implementations/offline_to_online/train.py: a self-contained SAC offline->online script (offline dataset via d4rl:/minari: string).
- Tests: hook + flow tests and a gated functional train() run on Pendulum.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 6477e4a commit 2893f5b

4 files changed

Lines changed: 628 additions & 0 deletions

File tree

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Offline-to-online SAC fine-tuning.
6+
7+
Warm-starts SAC on an offline dataset (D4RL/Minari) and fine-tunes it online via
8+
:class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`, sampling a mixed
9+
offline/online batch whose offline fraction is annealed to zero over
10+
``--anneal-frames`` collected frames.
11+
12+
Example::
13+
14+
python train.py --dataset d4rl:halfcheetah-medium-v2 --env HalfCheetah-v4
15+
python train.py --dataset minari:mujoco/halfcheetah/expert-v0 --total-frames 200000
16+
17+
Requires the dataset backend (``pip install d4rl`` or ``pip install minari``) and
18+
the matching MuJoCo environment.
19+
"""
20+
21+
from __future__ import annotations
22+
23+
import argparse
24+
25+
import torch
26+
from tensordict.nn import NormalParamExtractor, TensorDictModule
27+
from torch import nn
28+
29+
from torchrl.collectors import Collector
30+
from torchrl.data import OfflineToOnlineReplayBuffer
31+
from torchrl.data.datasets.utils import load_dataset
32+
from torchrl.envs import DoubleToFloat, GymEnv, TransformedEnv
33+
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
34+
from torchrl.objectives import SACLoss, SoftUpdate
35+
from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer
36+
37+
38+
def make_sac_modules(env, num_cells, device):
    """Build the SAC policy and Q-value modules for ``env``.

    Args:
        env: environment exposing ``observation_spec["observation"]`` and
            ``action_spec``; their trailing dimensions size the networks.
        num_cells: hidden-layer widths forwarded to both ``MLP`` instances.
        device: device forwarded to both ``MLP`` instances.

    Returns:
        An ``(actor, qvalue)`` tuple: a ``ProbabilisticActor`` using a
        ``TanhNormal`` distribution bounded by the action spec, and a
        ``ValueOperator`` writing ``"state_action_value"``.
    """
    n_obs = env.observation_spec["observation"].shape[-1]
    action_spec = env.action_spec
    n_act = action_spec.shape[-1]

    # Policy head: the MLP emits 2 * n_act features, which
    # NormalParamExtractor turns into the "loc" and "scale" entries.
    policy_mlp = MLP(
        in_features=n_obs,
        out_features=2 * n_act,
        num_cells=num_cells,
        device=device,
    )
    policy_params = TensorDictModule(
        nn.Sequential(policy_mlp, NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    action_bounds = {
        "low": action_spec.space.low,
        "high": action_spec.space.high,
    }
    actor = ProbabilisticActor(
        module=policy_params,
        in_keys=["loc", "scale"],
        spec=action_spec,
        distribution_class=TanhNormal,
        distribution_kwargs=action_bounds,
        return_log_prob=True,
    )

    # Critic: a scalar output computed from the observation and action entries.
    critic_mlp = MLP(
        in_features=n_obs + n_act,
        out_features=1,
        num_cells=num_cells,
        device=device,
    )
    qvalue = ValueOperator(
        critic_mlp,
        in_keys=["observation", "action"],
        out_keys=["state_action_value"],
    )
    return actor, qvalue
75+
76+
77+
def main(argv=None):
    """Parse the command line, assemble the SAC pipeline and run training.

    Builds, in order: the online environment, the SAC actor/critic and loss,
    the target-network updater and optimizer, the offline dataset wrapped in an
    ``OfflineToOnlineReplayBuffer``, the collector, and finally an
    ``OfflineToOnlineTrainer`` whose ``train()`` method is invoked.

    Args:
        argv: optional list of command-line tokens. ``None`` (the default)
            lets ``argparse`` read ``sys.argv[1:]``, so ``main()`` behaves as
            a regular script entry point.

    Raises:
        SystemExit: raised by ``argparse`` on invalid options, including an
            ``--offline-fraction`` outside ``[0, 1]``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks (notably the ``Example::``
        # commands) instead of letting argparse re-wrap them in ``--help``.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--env", default="HalfCheetah-v4", help="online gym env id")
    parser.add_argument(
        "--dataset",
        default="d4rl:halfcheetah-medium-v2",
        help="offline dataset id ('d4rl:<id>' or 'minari:<id>')",
    )
    parser.add_argument("--total-frames", type=int, default=1_000_000)
    parser.add_argument("--frames-per-batch", type=int, default=1000)
    parser.add_argument(
        "--anneal-frames",
        type=int,
        default=None,
        help="frames over which the offline fraction decays to 0 (default: half "
        "of --total-frames)",
    )
    parser.add_argument("--offline-fraction", type=float, default=0.5)
    parser.add_argument("--online-capacity", type=int, default=1_000_000)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--utd", type=int, default=64, help="optim steps per batch")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--num-cells", type=int, nargs="+", default=[256, 256])
    parser.add_argument("--tau", type=float, default=0.001)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args(argv)

    # Fail fast on a meaningless mixing ratio, before any dataset is loaded.
    # The chained comparison also rejects NaN.
    if not 0.0 <= args.offline_fraction <= 1.0:
        parser.error("--offline-fraction must be in [0, 1]")

    torch.manual_seed(args.seed)
    device = torch.device(args.device)

    # Online environment.
    env = TransformedEnv(GymEnv(args.env, device=device), DoubleToFloat())
    env.set_seed(args.seed)

    # SAC agent.
    actor, qvalue = make_sac_modules(env, args.num_cells, device)
    loss = SACLoss(actor_network=actor, qvalue_network=qvalue)
    loss.make_value_estimator(gamma=args.gamma)
    target_net_updater = SoftUpdate(loss, tau=args.tau)
    optimizer = torch.optim.Adam(loss.parameters(), lr=args.lr)

    # Offline dataset, with DoubleToFloat appended so it matches the online
    # stream (which goes through the same transform), paired with an online
    # buffer of capacity --online-capacity.
    offline = load_dataset(args.dataset)
    offline.append_transform(DoubleToFloat())
    replay_buffer = OfflineToOnlineReplayBuffer(
        offline_dataset=offline,
        online_capacity=args.online_capacity,
        offline_fraction=args.offline_fraction,
        batch_size=args.batch_size,
    )

    collector = Collector(
        env,
        actor,
        frames_per_batch=args.frames_per_batch,
        total_frames=args.total_frames,
        init_random_frames=0,  # the offline dataset already warm-starts learning
        device=device,
    )

    # An unset --anneal-frames falls back to half of the frame budget.
    anneal_frames = (
        args.anneal_frames if args.anneal_frames is not None else args.total_frames // 2
    )
    trainer = OfflineToOnlineTrainer(
        collector=collector,
        total_frames=args.total_frames,
        frame_skip=1,
        optim_steps_per_batch=args.utd,
        loss_module=loss,
        replay_buffer=replay_buffer,
        anneal_frames=anneal_frames,
        batch_size=args.batch_size,
        optimizer=optimizer,
        target_net_updater=target_net_updater,
        clip_grad_norm=False,
    )
    trainer.train()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)