Skip to content

Commit c50a082

Browse files
committed
[BugFix] Complete offline-to-online trainer wiring
1 parent 459fa96 commit c50a082

7 files changed

Lines changed: 234 additions & 4 deletions

File tree

docs/source/reference/config.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ Training and Optimization Configurations
480480
TrainerConfig
481481
PPOTrainerConfig
482482
SACTrainerConfig
483+
OfflineToOnlineTrainerConfig
483484
DQNTrainerConfig
484485
DDPGTrainerConfig
485486
IQLTrainerConfig
@@ -599,6 +600,7 @@ TorchRL currently provides configuration-driven trainers for the following algor
599600

600601
- **PPO** (on-policy): ``PPOTrainerConfig``, ``PPOLossConfig``
601602
- **SAC** (off-policy, continuous): ``SACTrainerConfig``, ``SACLossConfig``
603+
- **Offline-to-online SAC** (off-policy, continuous; offline pretraining followed by online fine-tuning): ``OfflineToOnlineTrainerConfig``, ``SACLossConfig``
602604
- **DQN** (off-policy, discrete): ``DQNTrainerConfig``, ``DQNLossConfig``
603605
- **DDPG** (off-policy, continuous): ``DDPGTrainerConfig``, ``DDPGLossConfig``
604606
- **IQL** (offline): ``IQLTrainerConfig``, ``IQLLossConfig``

docs/source/reference/trainers_basics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Algorithm-specific trainers
2626

2727
PPOTrainer
2828
SACTrainer
29+
OfflineToOnlineTrainer
2930
DQNTrainer
3031
DDPGTrainer
3132
IQLTrainer

sota-implementations/offline_to_online/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def main():
118118

119119
# Immutable offline dataset (DoubleToFloat to match the online float32 stream)
120120
# paired with a growing online buffer.
121-
offline = load_dataset(args.dataset)
121+
offline = load_dataset(args.dataset, batch_size=args.batch_size)
122122
offline.append_transform(DoubleToFloat())
123123
replay_buffer = OfflineToOnlineReplayBuffer(
124124
offline_dataset=offline,

test/test_offline_to_online.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import argparse
8+
import inspect
89

910
import pytest
1011
import torch
@@ -498,13 +499,21 @@ def test_state_dict_roundtrip(self):
498499
)
499500
hook = OfflineToOnlineReplayBufferHook(rb)
500501
hook.extend(_make_online_data(20))
502+
rb.anneal(step=50, total_steps=100)
501503

502504
rb2 = OfflineToOnlineReplayBuffer(
503-
offline_dataset=_make_offline_buffer(), online_capacity=500, batch_size=16
505+
offline_dataset=_make_offline_buffer(),
506+
online_capacity=500,
507+
offline_fraction=0.8,
508+
batch_size=16,
504509
)
505510
hook2 = OfflineToOnlineReplayBufferHook(rb2)
506511
hook2.load_state_dict(hook.state_dict())
507512
assert len(rb2.online_buffer) == 20
513+
assert rb2.offline_fraction == pytest.approx(0.25)
514+
515+
rb2.anneal(step=50, total_steps=100)
516+
assert rb2.offline_fraction == pytest.approx(0.25)
508517

509518

510519
class TestOfflineToOnlineAnnealHook:
@@ -555,6 +564,32 @@ def test_requires_offline_to_online_buffer(self):
555564
replay_buffer=plain,
556565
)
557566

567+
def test_constructor_exposes_sac_key_and_logging_kwargs(self):
568+
from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer
569+
570+
parameters = inspect.signature(OfflineToOnlineTrainer).parameters
571+
for name in (
572+
"log_rewards",
573+
"log_actions",
574+
"log_observations",
575+
"log_timings",
576+
"auto_log_optim_steps",
577+
"done_key",
578+
"terminated_key",
579+
"reward_key",
580+
"episode_reward_key",
581+
"action_key",
582+
"observation_key",
583+
):
584+
assert name in parameters
585+
586+
def test_config_registered(self):
587+
from torchrl.trainers.algorithms.configs import OfflineToOnlineTrainerConfig
588+
589+
assert OfflineToOnlineTrainerConfig._target_.endswith(
590+
"_make_offline_to_online_trainer"
591+
)
592+
558593
def test_hooks_drive_offline_online_flow(self):
559594
"""The three hooks together grow the online buffer, keep the mixed batch
560595
flat, and anneal the offline fraction -- the data path the trainer runs,

torchrl/trainers/algorithms/configs/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
DDPGTrainerConfig,
124124
DQNTrainerConfig,
125125
IQLTrainerConfig,
126+
OfflineToOnlineTrainerConfig,
126127
PPOTrainerConfig,
127128
SACTrainerConfig,
128129
TD3TrainerConfig,
@@ -397,6 +398,7 @@
397398
"DDPGTrainerConfig",
398399
"DQNTrainerConfig",
399400
"IQLTrainerConfig",
401+
"OfflineToOnlineTrainerConfig",
400402
"PPOTrainerConfig",
401403
"SACTrainerConfig",
402404
"TD3TrainerConfig",
@@ -671,6 +673,11 @@ def _register_configs():
671673
cs.store(group="trainer", name="ddpg", node=DDPGTrainerConfig)
672674
cs.store(group="trainer", name="dqn", node=DQNTrainerConfig)
673675
cs.store(group="trainer", name="iql", node=IQLTrainerConfig)
676+
cs.store(
677+
group="trainer",
678+
name="offline_to_online",
679+
node=OfflineToOnlineTrainerConfig,
680+
)
674681
cs.store(group="trainer", name="ppo", node=PPOTrainerConfig)
675682
cs.store(group="trainer", name="sac", node=SACTrainerConfig)
676683
cs.store(group="trainer", name="td3", node=TD3TrainerConfig)

torchrl/trainers/algorithms/configs/trainers.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from torchrl.trainers.algorithms.ddpg import DDPGTrainer
2323
from torchrl.trainers.algorithms.dqn import DQNTrainer
2424
from torchrl.trainers.algorithms.iql import IQLTrainer
25+
from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer
2526
from torchrl.trainers.algorithms.ppo import PPOTrainer
2627
from torchrl.trainers.algorithms.sac import SACTrainer
2728
from torchrl.trainers.algorithms.td3 import TD3Trainer
@@ -218,6 +219,147 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
218219
return trainer
219220

220221

222+
@dataclass
223+
class OfflineToOnlineTrainerConfig(SACTrainerConfig):
224+
"""Hydra configuration for
225+
:class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`.
226+
227+
Every kwarg accepted by ``OfflineToOnlineTrainer.__init__`` is exposed as a
228+
field here, with SAC network-construction helper fields inherited from
229+
:class:`SACTrainerConfig`.
230+
"""
231+
232+
anneal_frames: int | None = None
233+
234+
_target_: str = (
235+
"torchrl.trainers.algorithms.configs.trainers."
236+
"_make_offline_to_online_trainer"
237+
)
238+
239+
def __post_init__(self) -> None:
240+
"""Post-initialization hook for offline-to-online trainer configuration."""
241+
super().__post_init__()
242+
243+
244+
def _make_offline_to_online_trainer(*args, **kwargs) -> OfflineToOnlineTrainer:
245+
from torchrl.trainers.trainers import Logger
246+
247+
collector = kwargs.pop("collector")
248+
total_frames = kwargs.pop("total_frames")
249+
if total_frames is None:
250+
total_frames = collector.total_frames
251+
frame_skip = kwargs.pop("frame_skip", 1)
252+
optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
253+
loss_module = kwargs.pop("loss_module")
254+
optimizer = kwargs.pop("optimizer")
255+
logger = kwargs.pop("logger")
256+
clip_grad_norm = kwargs.pop("clip_grad_norm", True)
257+
clip_norm = kwargs.pop("clip_norm")
258+
progress_bar = kwargs.pop("progress_bar", True)
259+
replay_buffer = kwargs.pop("replay_buffer")
260+
save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
261+
log_interval = kwargs.pop("log_interval", 10000)
262+
save_trainer_file = kwargs.pop("save_trainer_file")
263+
seed = kwargs.pop("seed")
264+
actor_network = kwargs.pop("actor_network")
265+
critic_network = kwargs.pop("critic_network")
266+
kwargs.pop("create_env_fn")
267+
target_net_updater = kwargs.pop("target_net_updater")
268+
async_collection = kwargs.pop("async_collection", False)
269+
if async_collection:
270+
raise ValueError(
271+
"OfflineToOnlineTrainer does not support async_collection."
272+
)
273+
log_timings = kwargs.pop("log_timings", False)
274+
auto_log_optim_steps = kwargs.pop("auto_log_optim_steps", True)
275+
batch_size = kwargs.pop("batch_size", None)
276+
anneal_frames = kwargs.pop("anneal_frames", None)
277+
enable_logging = kwargs.pop("enable_logging", True)
278+
log_rewards = kwargs.pop("log_rewards", True)
279+
log_actions = kwargs.pop("log_actions", True)
280+
log_observations = kwargs.pop("log_observations", False)
281+
done_key = _normalize_hydra_key(kwargs.pop("done_key", "done"))
282+
terminated_key = _normalize_hydra_key(kwargs.pop("terminated_key", "terminated"))
283+
reward_key = _normalize_hydra_key(kwargs.pop("reward_key", "reward"))
284+
episode_reward_key = _normalize_hydra_key(
285+
kwargs.pop("episode_reward_key", "reward_sum")
286+
)
287+
action_key = _normalize_hydra_key(kwargs.pop("action_key", "action"))
288+
observation_key = _normalize_hydra_key(kwargs.pop("observation_key", "observation"))
289+
hooks = kwargs.pop("hooks", None)
290+
291+
# Instantiate networks first
292+
if actor_network is not None and not isinstance(actor_network, torch.nn.Module):
293+
actor_network = actor_network()
294+
if critic_network is not None and not isinstance(critic_network, torch.nn.Module):
295+
critic_network = critic_network()
296+
297+
if not isinstance(collector, BaseCollector):
298+
collector = collector()
299+
300+
if not isinstance(loss_module, LossModule):
301+
# then it's a partial config
302+
loss_module = loss_module(
303+
actor_network=actor_network, critic_network=critic_network
304+
)
305+
if target_net_updater is not None and not isinstance(
306+
target_net_updater, TargetNetUpdater
307+
):
308+
# target_net_updater must be a partial taking the loss as input
309+
target_net_updater = target_net_updater(loss_module)
310+
if not isinstance(optimizer, torch.optim.Optimizer):
311+
# then it's a partial config
312+
optimizer = optimizer(params=loss_module.parameters())
313+
314+
# Quick instance checks
315+
if not isinstance(collector, BaseCollector):
316+
raise ValueError(f"collector must be a BaseCollector, got {type(collector)}")
317+
if not isinstance(loss_module, LossModule):
318+
raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
319+
if not isinstance(optimizer, torch.optim.Optimizer):
320+
raise ValueError(
321+
f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
322+
)
323+
if not isinstance(logger, Logger) and logger is not None:
324+
raise ValueError(f"logger must be a Logger, got {type(logger)}")
325+
326+
trainer = OfflineToOnlineTrainer(
327+
collector=collector,
328+
total_frames=total_frames,
329+
frame_skip=frame_skip,
330+
optim_steps_per_batch=optim_steps_per_batch,
331+
loss_module=loss_module,
332+
replay_buffer=replay_buffer,
333+
anneal_frames=anneal_frames,
334+
batch_size=batch_size,
335+
optimizer=optimizer,
336+
logger=logger,
337+
clip_grad_norm=clip_grad_norm,
338+
clip_norm=clip_norm,
339+
progress_bar=progress_bar,
340+
seed=seed,
341+
save_trainer_interval=save_trainer_interval,
342+
log_interval=log_interval,
343+
save_trainer_file=save_trainer_file,
344+
enable_logging=enable_logging,
345+
log_rewards=log_rewards,
346+
log_actions=log_actions,
347+
log_observations=log_observations,
348+
target_net_updater=target_net_updater,
349+
async_collection=async_collection,
350+
log_timings=log_timings,
351+
auto_log_optim_steps=auto_log_optim_steps,
352+
done_key=done_key,
353+
terminated_key=terminated_key,
354+
reward_key=reward_key,
355+
episode_reward_key=episode_reward_key,
356+
action_key=action_key,
357+
observation_key=observation_key,
358+
)
359+
_register_trainer_hooks(trainer, hooks)
360+
return trainer
361+
362+
221363
@dataclass
222364
class PPOTrainerConfig(TrainerConfig):
223365
"""Hydra configuration for :class:`~torchrl.trainers.algorithms.PPOTrainer`.

torchrl/trainers/algorithms/offline_to_online.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from collections.abc import Callable
1111

1212
from tensordict import TensorDictBase
13+
from tensordict.utils import NestedKey
14+
from torch import optim
1315

1416
from torchrl.collectors import BaseCollector
1517
from torchrl.data.replay_buffers.offline_to_online import OfflineToOnlineReplayBuffer
@@ -89,10 +91,20 @@ def sample(self, batch: TensorDictBase) -> TensorDictBase:
8991
return sample.to(self.device) if self.device is not None else sample
9092

9193
def state_dict(self) -> dict:
92-
return {"online_buffer": self.replay_buffer.online_buffer.state_dict()}
94+
return {
95+
"online_buffer": self.replay_buffer.online_buffer.state_dict(),
96+
"offline_fraction": self.replay_buffer._offline_fraction,
97+
"base_offline_fraction": self.replay_buffer._base_offline_fraction,
98+
}
9399

94100
def load_state_dict(self, state_dict: dict) -> None:
95101
self.replay_buffer.online_buffer.load_state_dict(state_dict["online_buffer"])
102+
self.replay_buffer._offline_fraction = state_dict.get(
103+
"offline_fraction", self.replay_buffer._offline_fraction
104+
)
105+
self.replay_buffer._base_offline_fraction = state_dict.get(
106+
"base_offline_fraction", self.replay_buffer._base_offline_fraction
107+
)
96108

97109
def register(self, trainer, name: str = "replay_buffer") -> None:
98110
trainer.register_op("pre_epoch", self.extend)
@@ -136,6 +148,10 @@ def register(self, trainer, name: str = "offline_to_online_anneal") -> None:
136148
class OfflineToOnlineTrainer(SACTrainer):
137149
"""A SAC trainer for the offline-pretrain -> online-finetune transition.
138150
151+
See also
152+
:class:`~torchrl.trainers.algorithms.configs.OfflineToOnlineTrainerConfig`
153+
for the Hydra configuration counterpart.
154+
139155
Builds on :class:`~torchrl.trainers.algorithms.SACTrainer`, swapping the
140156
plain replay buffer for an :class:`~torchrl.data.OfflineToOnlineReplayBuffer`.
141157
Each collected batch is routed to the online buffer while optimization
@@ -175,7 +191,7 @@ def __init__(
175191
replay_buffer: OfflineToOnlineReplayBuffer,
176192
anneal_frames: int | None = None,
177193
batch_size: int | None = None,
178-
optimizer=None,
194+
optimizer: optim.Optimizer | None = None,
179195
logger: Logger | None = None,
180196
clip_grad_norm: bool = True,
181197
clip_norm: float | None = None,
@@ -185,13 +201,29 @@ def __init__(
185201
log_interval: int = 10000,
186202
save_trainer_file: str | pathlib.Path | None = None,
187203
enable_logging: bool = True,
204+
log_rewards: bool = True,
205+
log_actions: bool = True,
206+
log_observations: bool = False,
188207
target_net_updater: TargetNetUpdater | None = None,
208+
async_collection: bool = False,
209+
log_timings: bool = False,
210+
auto_log_optim_steps: bool = True,
211+
done_key: NestedKey = "done",
212+
terminated_key: NestedKey = "terminated",
213+
reward_key: NestedKey = "reward",
214+
episode_reward_key: NestedKey = "reward_sum",
215+
action_key: NestedKey = "action",
216+
observation_key: NestedKey = "observation",
189217
) -> None:
190218
if not isinstance(replay_buffer, OfflineToOnlineReplayBuffer):
191219
raise TypeError(
192220
"OfflineToOnlineTrainer requires an OfflineToOnlineReplayBuffer, "
193221
f"got {type(replay_buffer).__name__}."
194222
)
223+
if async_collection:
224+
raise ValueError(
225+
"OfflineToOnlineTrainer does not support async_collection."
226+
)
195227

196228
# Let SACTrainer wire up everything except the replay buffer (its
197229
# ReplayBufferTrainer assumes a sampler/priority API the offline-to-online
@@ -213,8 +245,19 @@ def __init__(
213245
save_trainer_file=save_trainer_file,
214246
replay_buffer=None,
215247
enable_logging=enable_logging,
248+
log_rewards=log_rewards,
249+
log_actions=log_actions,
250+
log_observations=log_observations,
216251
target_net_updater=target_net_updater,
217252
async_collection=False,
253+
log_timings=log_timings,
254+
auto_log_optim_steps=auto_log_optim_steps,
255+
done_key=done_key,
256+
terminated_key=terminated_key,
257+
reward_key=reward_key,
258+
episode_reward_key=episode_reward_key,
259+
action_key=action_key,
260+
observation_key=observation_key,
218261
)
219262

220263
self.replay_buffer = replay_buffer

0 commit comments

Comments
 (0)