Skip to content

Commit 9d34dbe

Browse files
committed
[BugFix] Replay Buffer prefetch & SliceSampler
Fix pickling for `ReplayBuffer` with prefetch by removing non-picklable prefetch objects and recreating them on unpickle. Fix logging typo in `SliceSampler`. ghstack-source-id: 7a8b94a Pull-Request: #3322
1 parent 7f9ea74 commit 9d34dbe

2 files changed

Lines changed: 18 additions & 2 deletions

File tree

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,13 @@ def __getstate__(self) -> dict[str, Any]:
11581158
state["_replay_lock_placeholder"] = None
11591159
if _futures_lock is not None:
11601160
state["_futures_lock_placeholder"] = None
1161+
# Remove non-picklable prefetch objects - they will be recreated on unpickle
1162+
_prefetch_queue = state.pop("_prefetch_queue", None)
1163+
_prefetch_executor = state.pop("_prefetch_executor", None)
1164+
if _prefetch_queue is not None:
1165+
state["_prefetch_queue_placeholder"] = None
1166+
if _prefetch_executor is not None:
1167+
state["_prefetch_executor_placeholder"] = None
11611168
return state
11621169

11631170
def __setstate__(self, state: dict[str, Any]):
@@ -1176,6 +1183,15 @@ def __setstate__(self, state: dict[str, Any]):
11761183
state.pop("_futures_lock_placeholder")
11771184
_futures_lock = threading.RLock()
11781185
state["_futures_lock"] = _futures_lock
1186+
# Recreate prefetch objects after unpickling if they were present
1187+
if "_prefetch_queue_placeholder" in state:
1188+
state.pop("_prefetch_queue_placeholder")
1189+
state["_prefetch_queue"] = collections.deque()
1190+
if "_prefetch_executor_placeholder" in state:
1191+
state.pop("_prefetch_executor_placeholder")
1192+
state["_prefetch_executor"] = ThreadPoolExecutor(
1193+
max_workers=state["_prefetch_cap"]
1194+
)
11791195
self.__dict__.update(state)
11801196
if rngstate is not None:
11811197
self.set_rng(rng)

torchrl/data/replay_buffers/samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,10 +1154,10 @@ def __init__(
11541154
def __getstate__(self):
11551155
if get_spawning_popen() is not None and self.cache_values:
11561156
logger.warning(
1157-
f"It seems you are sharing a {type(self).__name__} across processes with"
1157+
f"It seems you are sharing a {type(self).__name__} across processes with "
11581158
f"cache_values=True. "
11591159
f"While this isn't forbidden and could perfectly work if your dataset "
1160-
f"is unaltered on both processes, remember that calling extend/add on"
1160+
f"is unaltered on both processes, remember that calling extend/add on "
11611161
f"one process will NOT erase the cache on another process's sampler, "
11621162
f"which will cause synchronization issues."
11631163
)

0 commit comments

Comments
 (0)