Skip to content

Commit 1ae2779

Browse files
authored
[BugFix] Buffer prefetching off by one (#3871)
1 parent ae403e6 commit 1ae2779

2 files changed

Lines changed: 25 additions & 2 deletions

File tree

test/rb/test_rb_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,26 @@ def test_replay_buffer_iter(size, drop_last):
623623
assert i == (size - 1) // 3
624624

625625

626+
def test_replay_buffer_prefetch_queue_length():
627+
"""Test that the prefetch queue maintains the correct length.
628+
629+
This test verifies that after sampling from a replay buffer with prefetching
630+
enabled, the prefetch queue has exactly `prefetch` items computing in the
631+
background (no off-by-one error).
632+
"""
633+
torch.manual_seed(0)
634+
635+
rb = ReplayBuffer(storage=ListStorage(max_size=100), batch_size=2, prefetch=2)
636+
637+
rb.extend(torch.arange(100))
638+
639+
_ = rb.sample()
640+
641+
assert (
642+
len(rb._prefetch_queue) == 2
643+
), f"Expected prefetch queue to have 2 items, but got {len(rb._prefetch_queue)}."
644+
645+
626646
if __name__ == "__main__":
627647
args, unknown = argparse.ArgumentParser().parse_known_args()
628648
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,14 +1154,17 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
11541154
result = self._sample(batch_size)
11551155
else:
11561156
with self._futures_lock:
1157+
if len(self._prefetch_queue):
1158+
result = self._prefetch_queue.popleft().result()
1159+
else:
1160+
result = self._sample(batch_size)
11571161
while (
11581162
len(self._prefetch_queue)
11591163
< min(self._sampler._remaining_batches, self._prefetch_cap)
11601164
and not self._sampler.ran_out
1161-
) or not len(self._prefetch_queue):
1165+
):
11621166
fut = self._prefetch_executor.submit(self._sample, batch_size)
11631167
self._prefetch_queue.append(fut)
1164-
result = self._prefetch_queue.popleft().result()
11651168

11661169
if return_info:
11671170
out, info = result

0 commit comments

Comments
 (0)