Commit f7ba109

[Feature] Offline-to-online transition utilities for replay buffers (#3900)
1 parent 2523d28 commit f7ba109

8 files changed

Lines changed: 954 additions & 0 deletions

docs/source/reference/data_datasets.rst

Lines changed: 55 additions & 0 deletions
@@ -14,6 +14,54 @@ The TED format is TorchRL's standard data layout for offline RL datasets. It str
 trajectories as nested tensors where each element contains the full trajectory data,
 making it efficient for sequence-based sampling and training.

+
+Dataset loading registry
+------------------------
+
+Many offline datasets can be constructed from a compact string identifier with
+:func:`~torchrl.data.datasets.load_dataset`. The identifier is split as
+``"<source>:<dataset-id>"``. The ``source`` prefix selects a registered dataset
+factory and the remainder is forwarded as the first constructor argument, with
+any extra keyword arguments forwarded unchanged.
+
+TorchRL registers the built-in dataset families at import time, including
+``"atari"``, ``"atari_dqn"``, ``"d4rl"``, ``"gen_dgrl"``, ``"lerobot"``,
+``"minari"``, ``"openml"``, ``"openx"``, ``"roboset"``, and ``"vd4rl"``.
+For example:
+
+.. code-block:: python
+
+    from torchrl.data.datasets import load_dataset
+
+    dataset = load_dataset("d4rl:halfcheetah-medium-v2", batch_size=256)
+    minari_dataset = load_dataset(
+        "minari:mujoco/hopper/expert-v0",
+        batch_size=256,
+        split_trajs=True,
+    )
+
+Projects can add their own sources with
+:func:`~torchrl.data.datasets.register_dataset`. A registered factory can be a
+callable or a lazy import string of the form ``"module:attribute"``. Factories
+are called as ``factory(dataset_id, **kwargs)``.
+
+.. code-block:: python
+
+    from torchrl.data.datasets import load_dataset, register_dataset
+
+    class MyDataset:
+        def __init__(self, dataset_id, *, batch_size):
+            self.dataset_id = dataset_id
+            self.batch_size = batch_size
+
+    register_dataset("my_backend", MyDataset)
+    dataset = load_dataset("my_backend:my-task-v0", batch_size=128)
+
+    register_dataset(
+        "my_lazy_backend",
+        "my_package.datasets:MyDataset",
+    )
+
 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst
@@ -33,3 +81,10 @@ making it efficient for sequence-based sampling and training.
     :template: rl_template_noinherit.rst

     datasets.lerobot_columns_to_tensordict
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_fun.rst
+
+    datasets.load_dataset
+    datasets.register_dataset
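
Reviewer sketch (not part of the commit): the registry text added to docs/source/reference/data_datasets.rst describes splitting an identifier as ``"<source>:<dataset-id>"``, looking up a factory by source, resolving lazy ``"module:attribute"`` strings, and calling ``factory(dataset_id, **kwargs)``. The standalone Python below is one way such a dispatch could work, using only the standard library. It is not TorchRL's implementation: the names ``register_dataset_sketch``, ``load_dataset_sketch``, ``_REGISTRY`` and ``_resolve`` are hypothetical, and splitting on the first colon only is an assumption. ``fractions:Fraction`` stands in for a lazily imported dataset class purely so the example runs without extra packages.

```python
import importlib
from typing import Any, Callable, Dict, Union

Factory = Union[str, Callable[..., Any]]

# Hypothetical registry (not TorchRL's): source prefix -> callable or "module:attribute".
_REGISTRY: Dict[str, Factory] = {}


def register_dataset_sketch(source: str, factory: Factory) -> None:
    """Associate a source prefix with a callable or a lazy import string."""
    _REGISTRY[source] = factory


def _resolve(factory: Factory) -> Callable[..., Any]:
    """Import a "module:attribute" string when needed; pass callables through."""
    if isinstance(factory, str):
        module_name, _, attribute = factory.partition(":")
        return getattr(importlib.import_module(module_name), attribute)
    return factory


def load_dataset_sketch(identifier: str, **kwargs: Any) -> Any:
    """Dispatch "<source>:<dataset-id>" to the factory registered for <source>."""
    # Assumption: split on the first colon only, so the dataset id may contain "/" or ":".
    source, sep, dataset_id = identifier.partition(":")
    if not sep:
        raise ValueError(f"expected '<source>:<dataset-id>', got {identifier!r}")
    if source not in _REGISTRY:
        raise KeyError(f"no dataset factory registered for source {source!r}")
    # The remainder is the first positional argument; keyword arguments pass through.
    return _resolve(_REGISTRY[source])(dataset_id, **kwargs)


class MyDataset:
    def __init__(self, dataset_id, *, batch_size):
        self.dataset_id = dataset_id
        self.batch_size = batch_size


# Eager registration with a callable.
register_dataset_sketch("my_backend", MyDataset)
dataset = load_dataset_sketch("my_backend:my-task-v0", batch_size=128)

# Lazy registration: the module is imported only when the source is first loaded.
register_dataset_sketch("lazy_demo", "fractions:Fraction")
lazy_value = load_dataset_sketch("lazy_demo:3/4")
```

Storing the import string and resolving it at load time keeps registration cheap, which is presumably why the documented API accepts ``"module:attribute"`` strings for backends with heavy optional dependencies.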

docs/source/reference/data_replaybuffers.rst

Lines changed: 11 additions & 0 deletions
@@ -14,13 +14,24 @@ Core Replay Buffer Classes
     :template: rl_template.rst

     ReplayBuffer
+    OfflineToOnlineReplayBuffer
     ReplayBufferEnsemble
     PrioritizedReplayBuffer
     TensorDictReplayBuffer
     TensorDictPrioritizedReplayBuffer
     RayReplayBuffer
     RemoteTensorDictReplayBuffer

+
+Offline-to-online helpers
+-------------------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_fun.rst
+
+    prefill_replay_buffer
+
 Composable Replay Buffers
 -------------------------
