Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,6 +1803,30 @@ def test_batch_errors():
rb.sample()


def test_storage_save_hook(tmpdir):
    """Check that ``dumps`` fills ``shift`` and ``is_full`` on a user save hook.

    The hook records the attribute values it sees at call time, so the test
    verifies both the final attribute values and the values that were visible
    while the hook was running.
    """
    seen_at_call = {}

    class SaveHook:
        # Placeholders; the assertions below expect them to be overwritten
        # by the time the hook is invoked.
        shift = None
        is_full = None

        def __call__(self, data, path=None):
            seen_at_call.update(shift=self.shift, is_full=self.is_full)
            return data

    hook = SaveHook()
    storage = LazyMemmapStorage(10)
    rb = ReplayBuffer(storage=storage)
    rb.register_save_hook(hook)

    # Five elements written into a storage of size ten.
    rb.extend(torch.arange(5))
    rb.dumps(tmpdir)

    # Attributes as they stand after saving.
    assert hook.shift == 5, f"Expected shift=5, got {hook.shift}"
    assert hook.is_full is False, f"Expected is_full=False, got {hook.is_full}"
    # Attributes as the hook saw them during its own call.
    assert seen_at_call["shift"] == 5
    assert seen_at_call["is_full"] is False


@pytest.mark.skipif(not torchrl._utils.RL_WARNINGS, reason="RL_WARNINGS is not set")
def test_add_warning():
from torchrl._utils import rl_warnings
Expand Down
91 changes: 54 additions & 37 deletions torchrl/data/replay_buffers/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,52 @@ class StorageCheckpointerBase:

"""

def __init__(self):
    """Create empty, per-instance lists of save and load hooks."""
    # Two distinct list objects, assigned left to right.
    self._save_hooks, self._load_hooks = [], []

def register_save_hook(self, hook):
    """Append ``hook`` to this checkpointer's list of save hooks."""
    save_hooks = self._save_hooks
    save_hooks.append(hook)

def register_load_hook(self, hook):
    """Append ``hook`` to this checkpointer's list of load hooks."""
    load_hooks = self._load_hooks
    load_hooks.append(hook)

def _get_shift_from_last_cursor(self, last_cursor):
    """Return one more than the final index described by ``last_cursor``.

    Accepts a ``slice`` (uses ``stop``), an ``int``, a ``range``, a
    ``torch.Tensor`` or a ``numpy.ndarray`` (both flattened, last element
    used).

    Raises:
        ValueError: if ``last_cursor`` is of any other type.
    """
    if isinstance(last_cursor, slice):
        # NOTE(review): ``slice.stop`` is usually exclusive; the ``+ 1`` below
        # presumably relies on how the storage records its cursor - confirm.
        final_index = last_cursor.stop
    elif isinstance(last_cursor, int):
        final_index = last_cursor
    elif isinstance(last_cursor, range):
        final_index = last_cursor[-1]
    elif isinstance(last_cursor, (torch.Tensor, np.ndarray)):
        # Flatten so that multi-dimensional cursors yield their last entry.
        final_index = last_cursor.reshape(-1)[-1].item()
    else:
        raise ValueError(f"Unrecognised last_cursor type {type(last_cursor)}.")
    return final_index + 1

def _set_hooks_shift_is_full(self, storage):
    """Copy the storage's ``_is_full`` flag and cursor shift onto the save hooks.

    Only hooks that already expose an ``is_full`` (resp. ``shift``) attribute
    are updated. When ``storage._last_cursor`` is ``None`` a warning is
    emitted and a shift of ``0`` is used.
    """
    full_flag = storage._is_full
    cursor = storage._last_cursor

    # First pass: the fullness flag.
    for save_hook in self._save_hooks:
        if hasattr(save_hook, "is_full"):
            save_hook.is_full = full_flag

    if cursor is not None:
        shift = self._get_shift_from_last_cursor(cursor)
    else:
        warnings.warn(
            "last_cursor is None. The replay buffer "
            "may not be saved properly in this setting. To solve this issue, make "
            "sure the storage updates the _last_cursor value during calls to `set`."
        )
        shift = 0

    # Second pass: the shift derived from the cursor.
    for save_hook in self._save_hooks:
        if hasattr(save_hook, "shift"):
            save_hook.shift = shift

@abc.abstractmethod
def dumps(self, storage, path):
    """Write ``storage`` to ``path``; abstract, so subclasses must override it."""
    ...
Expand Down Expand Up @@ -287,9 +333,6 @@ class TensorStorageCheckpointer(StorageCheckpointerBase):

"""

_save_hooks = []
_load_hooks = []

def dumps(self, storage, path):
path = Path(path)
path.mkdir(exist_ok=True)
Expand All @@ -298,6 +341,9 @@ def dumps(self, storage, path):
raise RuntimeError("Cannot save a non-initialized storage.")
metadata = {}
_storage = storage._storage

self._set_hooks_shift_is_full(storage)

for hook in self._save_hooks:
_storage = hook(_storage, path=path)
if is_tensor_collection(_storage):
Expand Down Expand Up @@ -424,6 +470,7 @@ class FlatStorageCheckpointer(TensorStorageCheckpointer):
"""

def __init__(self, done_keys=None, reward_keys=None):
super().__init__()
kwargs = {}
if done_keys is not None:
kwargs["done_keys"] = done_keys
Expand All @@ -432,38 +479,6 @@ def __init__(self, done_keys=None, reward_keys=None):
self._save_hooks = [TED2Flat(**kwargs)]
self._load_hooks = [Flat2TED(**kwargs)]

def _save_shift_is_full(self, storage):
    """Copy the storage's ``_is_full`` flag and cursor shift onto the save hooks.

    Only hooks that already expose an ``is_full`` (resp. ``shift``) attribute
    are updated.

    If ``storage._last_cursor`` is ``None`` a warning is emitted and the shift
    falls back to ``0``. Previously ``None`` was still passed on to
    ``_get_shift_from_last_cursor``, which raises ``ValueError`` for it.
    """
    is_full = storage._is_full
    last_cursor = storage._last_cursor
    for hook in self._save_hooks:
        if hasattr(hook, "is_full"):
            hook.is_full = is_full
    if last_cursor is None:
        warnings.warn(
            "last_cursor is None. The replay buffer "
            "may not be saved properly in this setting. To solve this issue, make "
            "sure the storage updates the _last_cursor value during calls to `set`."
        )
        # No cursor to derive a shift from: use zero rather than crash.
        shift = 0
    else:
        shift = self._get_shift_from_last_cursor(last_cursor)
    for hook in self._save_hooks:
        if hasattr(hook, "shift"):
            hook.shift = shift

def dumps(self, storage, path):
    """Update the save hooks from ``storage``, then defer to the parent ``dumps``."""
    # The hooks' ``shift``/``is_full`` attributes are set before the parent
    # implementation is called.
    self._save_shift_is_full(storage)
    return super().dumps(storage, path)

def _get_shift_from_last_cursor(self, last_cursor):
    """Return one more than the final index described by ``last_cursor``.

    Accepts a ``slice`` (uses ``stop``), an ``int``, a ``range``, a
    ``torch.Tensor`` or a ``numpy.ndarray`` (both flattened, last element
    used). ``range`` cursors were previously rejected with ``ValueError``.

    Raises:
        ValueError: if ``last_cursor`` is of any other type.
    """
    if isinstance(last_cursor, slice):
        return last_cursor.stop + 1
    if isinstance(last_cursor, int):
        return last_cursor + 1
    if isinstance(last_cursor, range):
        return last_cursor[-1] + 1
    if isinstance(last_cursor, (torch.Tensor, np.ndarray)):
        # Flatten so that multi-dimensional cursors yield their last entry.
        return last_cursor.reshape(-1)[-1].item() + 1
    raise ValueError(f"Unrecognised last_cursor type {type(last_cursor)}.")


class NestedStorageCheckpointer(FlatStorageCheckpointer):
"""Saves the storage in a compact form, saving space on the TED format and using memory-mapped nested tensors.
Expand All @@ -478,7 +493,8 @@ class NestedStorageCheckpointer(FlatStorageCheckpointer):

"""

def __init__(self, done_keys=None, reward_keys=None, **kwargs):
def __init__(self, done_keys=None, reward_keys=None):
super().__init__()
kwargs = {}
if done_keys is not None:
kwargs["done_keys"] = done_keys
Expand Down Expand Up @@ -522,6 +538,7 @@ def __init__(
h5_kwargs=None,
**kwargs,
):
StorageCheckpointerBase.__init__(self)
ted2_kwargs = kwargs
if done_keys is not None:
ted2_kwargs["done_keys"] = done_keys
Expand All @@ -535,7 +552,7 @@ def __init__(
def dumps(self, storage, path):
path = self._get_path(path)

self._save_shift_is_full(storage)
self._set_hooks_shift_is_full(storage)

if not storage.initialized:
raise RuntimeError("Cannot save a non-initialized storage.")
Expand Down
14 changes: 14 additions & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,20 @@ def __init__(
def checkpointer(self):
return self._checkpointer

def register_save_hook(self, hook):
    """Attach a save hook to this storage.

    The storage does not keep the hook itself; it hands it to its
    checkpointer's ``register_save_hook``.
    """
    checkpointer = self._checkpointer
    checkpointer.register_save_hook(hook)

def register_load_hook(self, hook):
    """Attach a load hook to this storage.

    The storage does not keep the hook itself; it hands it to its
    checkpointer's ``register_load_hook``.
    """
    checkpointer = self._checkpointer
    checkpointer.register_load_hook(hook)

@checkpointer.setter
def checkpointer(self, value: StorageCheckpointerBase | None) -> None:
if value is None:
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
{
"params": TensorDict(
{
"max_speed": 8,
"max_speed": 8.0,
"max_torque": 2.0,
"dt": 0.05,
"g": g,
Expand Down
Loading