Skip to content

Commit 07dbcac

Browse files
committed
[Doc] Document dataset loader registry
1 parent 0bcc820 commit 07dbcac

1 file changed

Lines changed: 48 additions & 0 deletions

File tree

docs/source/reference/data_datasets.rst

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,54 @@ The TED format is TorchRL's standard data layout for offline RL datasets. It str
1414
trajectories as nested tensors where each element contains the full trajectory data,
1515
making it efficient for sequence-based sampling and training.
1616

17+
18+
Dataset loading registry
19+
------------------------
20+
21+
Many offline datasets can be constructed from a compact string identifier with
22+
:func:`~torchrl.data.datasets.load_dataset`. The identifier is split as
23+
``"<source>:<dataset-id>"``. The ``source`` prefix selects a registered dataset
24+
factory and the remainder is forwarded as the first constructor argument, with
25+
any extra keyword arguments forwarded unchanged.
26+
27+
TorchRL registers the built-in dataset families at import time, including
28+
``"atari"``, ``"atari_dqn"``, ``"d4rl"``, ``"gen_dgrl"``, ``"lerobot"``,
29+
``"minari"``, ``"openml"``, ``"openx"``, ``"roboset"``, and ``"vd4rl"``.
30+
For example:
31+
32+
.. code-block:: python
33+
34+
from torchrl.data.datasets import load_dataset
35+
36+
dataset = load_dataset("d4rl:halfcheetah-medium-v2", batch_size=256)
37+
minari_dataset = load_dataset(
38+
"minari:mujoco/hopper/expert-v0",
39+
batch_size=256,
40+
split_trajs=True,
41+
)
42+
43+
Projects can add their own sources with
44+
:func:`~torchrl.data.datasets.register_dataset`. A registered factory can be a
45+
callable or a lazy import string of the form ``"module:attribute"``. Factories
46+
are called as ``factory(dataset_id, **kwargs)``.
47+
48+
.. code-block:: python
49+
50+
from torchrl.data.datasets import load_dataset, register_dataset
51+
52+
class MyDataset:
53+
def __init__(self, dataset_id, *, batch_size):
54+
self.dataset_id = dataset_id
55+
self.batch_size = batch_size
56+
57+
register_dataset("my_backend", MyDataset)
58+
dataset = load_dataset("my_backend:my-task-v0", batch_size=128)
59+
60+
register_dataset(
61+
"my_lazy_backend",
62+
"my_package.datasets:MyDataset",
63+
)
64+
1765
.. autosummary::
1866
:toctree: generated/
1967
:template: rl_template.rst

0 commit comments

Comments
 (0)