|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | +"""Offline-to-online SAC fine-tuning. |
| 6 | +
|
| 7 | +Warm-starts SAC on an offline dataset (D4RL/Minari) and fine-tunes it online via |
| 8 | +:class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`, sampling a mixed |
| 9 | +offline/online batch whose offline fraction is annealed to zero over |
| 10 | +``--anneal-frames`` collected frames. |
| 11 | +
|
| 12 | +Example:: |
| 13 | +
|
| 14 | + python train.py --dataset d4rl:halfcheetah-medium-v2 --env HalfCheetah-v4 |
| 15 | + python train.py --dataset minari:mujoco/halfcheetah/expert-v0 --total-frames 200000 |
| 16 | +
|
| 17 | +Requires the dataset backend (``pip install d4rl`` or ``pip install minari``) and |
| 18 | +the matching MuJoCo environment. |
| 19 | +""" |
| 20 | + |
| 21 | +from __future__ import annotations |
| 22 | + |
| 23 | +import argparse |
| 24 | + |
| 25 | +import torch |
| 26 | +from tensordict.nn import NormalParamExtractor, TensorDictModule |
| 27 | +from torch import nn |
| 28 | + |
| 29 | +from torchrl.collectors import Collector |
| 30 | +from torchrl.data import OfflineToOnlineReplayBuffer |
| 31 | +from torchrl.data.datasets.utils import load_dataset |
| 32 | +from torchrl.envs import DoubleToFloat, GymEnv, TransformedEnv |
| 33 | +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator |
| 34 | +from torchrl.objectives import SACLoss, SoftUpdate |
| 35 | +from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer |
| 36 | + |
| 37 | + |
| 38 | +def make_sac_modules(env, num_cells, device): |
| 39 | + obs_dim = env.observation_spec["observation"].shape[-1] |
| 40 | + action_dim = env.action_spec.shape[-1] |
| 41 | + |
| 42 | + actor_net = nn.Sequential( |
| 43 | + MLP( |
| 44 | + in_features=obs_dim, |
| 45 | + out_features=2 * action_dim, |
| 46 | + num_cells=num_cells, |
| 47 | + device=device, |
| 48 | + ), |
| 49 | + NormalParamExtractor(), |
| 50 | + ) |
| 51 | + actor = ProbabilisticActor( |
| 52 | + module=TensorDictModule( |
| 53 | + actor_net, in_keys=["observation"], out_keys=["loc", "scale"] |
| 54 | + ), |
| 55 | + in_keys=["loc", "scale"], |
| 56 | + spec=env.action_spec, |
| 57 | + distribution_class=TanhNormal, |
| 58 | + distribution_kwargs={ |
| 59 | + "low": env.action_spec.space.low, |
| 60 | + "high": env.action_spec.space.high, |
| 61 | + }, |
| 62 | + return_log_prob=True, |
| 63 | + ) |
| 64 | + qvalue = ValueOperator( |
| 65 | + MLP( |
| 66 | + in_features=obs_dim + action_dim, |
| 67 | + out_features=1, |
| 68 | + num_cells=num_cells, |
| 69 | + device=device, |
| 70 | + ), |
| 71 | + in_keys=["observation", "action"], |
| 72 | + out_keys=["state_action_value"], |
| 73 | + ) |
| 74 | + return actor, qvalue |
| 75 | + |
| 76 | + |
| 77 | +def main(): |
| 78 | + parser = argparse.ArgumentParser(description=__doc__) |
| 79 | + parser.add_argument("--env", default="HalfCheetah-v4", help="online gym env id") |
| 80 | + parser.add_argument( |
| 81 | + "--dataset", |
| 82 | + default="d4rl:halfcheetah-medium-v2", |
| 83 | + help="offline dataset id ('d4rl:<id>' or 'minari:<id>')", |
| 84 | + ) |
| 85 | + parser.add_argument("--total-frames", type=int, default=1_000_000) |
| 86 | + parser.add_argument("--frames-per-batch", type=int, default=1000) |
| 87 | + parser.add_argument( |
| 88 | + "--anneal-frames", |
| 89 | + type=int, |
| 90 | + default=None, |
| 91 | + help="frames over which the offline fraction decays to 0 (default: half " |
| 92 | + "of --total-frames)", |
| 93 | + ) |
| 94 | + parser.add_argument("--offline-fraction", type=float, default=0.5) |
| 95 | + parser.add_argument("--online-capacity", type=int, default=1_000_000) |
| 96 | + parser.add_argument("--batch-size", type=int, default=256) |
| 97 | + parser.add_argument("--utd", type=int, default=64, help="optim steps per batch") |
| 98 | + parser.add_argument("--lr", type=float, default=3e-4) |
| 99 | + parser.add_argument("--num-cells", type=int, nargs="+", default=[256, 256]) |
| 100 | + parser.add_argument("--tau", type=float, default=0.001) |
| 101 | + parser.add_argument("--seed", type=int, default=42) |
| 102 | + parser.add_argument("--device", default="cpu") |
| 103 | + args = parser.parse_args() |
| 104 | + |
| 105 | + torch.manual_seed(args.seed) |
| 106 | + device = torch.device(args.device) |
| 107 | + |
| 108 | + # Online environment. |
| 109 | + env = TransformedEnv(GymEnv(args.env, device=device), DoubleToFloat()) |
| 110 | + env.set_seed(args.seed) |
| 111 | + |
| 112 | + # SAC agent. |
| 113 | + actor, qvalue = make_sac_modules(env, args.num_cells, device) |
| 114 | + loss = SACLoss(actor_network=actor, qvalue_network=qvalue) |
| 115 | + loss.make_value_estimator(gamma=0.99) |
| 116 | + target_net_updater = SoftUpdate(loss, tau=args.tau) |
| 117 | + optimizer = torch.optim.Adam(loss.parameters(), lr=args.lr) |
| 118 | + |
| 119 | + # Immutable offline dataset (DoubleToFloat to match the online float32 stream) |
| 120 | + # paired with a growing online buffer. |
| 121 | + offline = load_dataset(args.dataset) |
| 122 | + offline.append_transform(DoubleToFloat()) |
| 123 | + replay_buffer = OfflineToOnlineReplayBuffer( |
| 124 | + offline_dataset=offline, |
| 125 | + online_capacity=args.online_capacity, |
| 126 | + offline_fraction=args.offline_fraction, |
| 127 | + batch_size=args.batch_size, |
| 128 | + ) |
| 129 | + |
| 130 | + collector = Collector( |
| 131 | + env, |
| 132 | + actor, |
| 133 | + frames_per_batch=args.frames_per_batch, |
| 134 | + total_frames=args.total_frames, |
| 135 | + init_random_frames=0, # the offline dataset already warm-starts learning |
| 136 | + device=device, |
| 137 | + ) |
| 138 | + |
| 139 | + anneal_frames = ( |
| 140 | + args.anneal_frames if args.anneal_frames is not None else args.total_frames // 2 |
| 141 | + ) |
| 142 | + trainer = OfflineToOnlineTrainer( |
| 143 | + collector=collector, |
| 144 | + total_frames=args.total_frames, |
| 145 | + frame_skip=1, |
| 146 | + optim_steps_per_batch=args.utd, |
| 147 | + loss_module=loss, |
| 148 | + replay_buffer=replay_buffer, |
| 149 | + anneal_frames=anneal_frames, |
| 150 | + batch_size=args.batch_size, |
| 151 | + optimizer=optimizer, |
| 152 | + target_net_updater=target_net_updater, |
| 153 | + clip_grad_norm=False, |
| 154 | + ) |
| 155 | + trainer.train() |
| 156 | + |
| 157 | + |
| 158 | +if __name__ == "__main__": |
| 159 | + main() |
0 commit comments