Skip to content

Commit 79a39c0

Browse files
fix(fabric): forward set_epoch to underlying sampler in DistributedSamplerWrapper (#21456)
The DistributedSamplerWrapper now forwards set_epoch() calls to the underlying sampler if it supports the method. This fix is generic and works for any sampler subclass that implements set_epoch(), not just specific implementations. This is important for samplers that use the epoch for shuffling or other epoch-dependent behavior in distributed training.

Fixes #21454

Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent aa0ee0d commit 79a39c0

3 files changed

Lines changed: 81 additions & 0 deletions

File tree

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1717

1818
### Fixed
1919

20+
- Fixed `DistributedSamplerWrapper` not forwarding `set_epoch` to the underlying sampler ([#21456](https://github.com/Lightning-AI/pytorch-lightning/pull/21456))
21+
2022
- Fixed DDP notebook CUDA fork check to allow passive initialization when CUDA is not actively used ([#21402](https://github.com/Lightning-AI/pytorch-lightning/pull/21402))
2123

2224
### Removed

src/lightning/fabric/utilities/distributed.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,14 @@ def __iter__(self) -> Iterator:
372372
self.dataset.reset()
373373
return (self.dataset[index] for index in super().__iter__())
374374

375+
@override
def set_epoch(self, epoch: int) -> None:
    """Set the epoch on this wrapper and propagate it to the wrapped sampler.

    The parent implementation is always invoked first. The sampler stored at
    ``self.dataset._sampler`` then receives the same ``epoch`` value, but only
    when it exposes a callable ``set_epoch`` attribute; samplers without one
    (or with a non-callable attribute of that name) are left untouched.

    Args:
        epoch: The epoch number to apply.
    """
    super().set_epoch(epoch)
    # A single lookup with a ``None`` fallback covers both "attribute missing"
    # and "attribute present but not callable".
    inner_set_epoch = getattr(self.dataset._sampler, "set_epoch", None)
    if callable(inner_set_epoch):
        inner_set_epoch(epoch)
382+
375383

376384
def _suggested_max_num_threads(num_processes: int = 1) -> int:
377385
if num_processes < 1:

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
1616
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
1717
from lightning.fabric.utilities.distributed import (
18+
DistributedSamplerWrapper,
1819
_destroy_dist_connection,
1920
_gather_all_tensors,
2021
_get_default_process_group_backend_for_device,
@@ -274,3 +275,73 @@ def test_is_dtensor(monkeypatch):
274275

275276
monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
276277
assert not _is_dtensor(Mock(spec=DTensor))
278+
279+
280+
class _CustomSampler(torch.utils.data.Sampler):
281+
"""A custom sampler for testing DistributedSamplerWrapper."""
282+
283+
def __init__(self, data_source, non_callable_set_epoch: bool = False):
284+
self.data_source = data_source
285+
if non_callable_set_epoch:
286+
self.set_epoch = "not a method" # attribute exists but is not callable
287+
288+
def __len__(self):
289+
return len(self.data_source)
290+
291+
def __iter__(self):
292+
return iter(range(len(self.data_source)))
293+
294+
295+
class _CustomSamplerWithSetEpoch(_CustomSampler):
    """Variant of ``_CustomSampler`` that records every ``set_epoch`` call.

    ``epoch`` holds the most recently supplied value (initially ``0``) and
    ``set_epoch_call_count`` holds how many times ``set_epoch`` has run.
    """

    def __init__(self, data_source):
        super().__init__(data_source)
        self.epoch = 0
        self.set_epoch_call_count = 0

    def set_epoch(self, epoch):
        """Store ``epoch`` and bump the call counter."""
        self.set_epoch_call_count += 1
        self.epoch = epoch
306+
307+
308+
def test_distributed_sampler_wrapper_set_epoch():
    """Check how ``DistributedSamplerWrapper.set_epoch`` treats different wrapped samplers.

    Regression test for issue #21454: calling ``set_epoch`` on the wrapper must also reach the wrapped sampler
    whenever that sampler provides a callable ``set_epoch``, and must be harmless otherwise.

    """
    data_source = list(range(100))

    # Wrapped sampler implements set_epoch: every call on the wrapper must reach it.
    tracking_sampler = _CustomSamplerWithSetEpoch(data_source)
    wrapper = DistributedSamplerWrapper(tracking_sampler, num_replicas=2, rank=0)

    assert tracking_sampler.epoch == 0
    assert tracking_sampler.set_epoch_call_count == 0

    wrapper.set_epoch(5)
    assert wrapper.epoch == 5
    assert tracking_sampler.epoch == 5, "set_epoch was not forwarded to the underlying sampler"
    assert tracking_sampler.set_epoch_call_count == 1

    wrapper.set_epoch(10)
    assert wrapper.epoch == 10
    assert tracking_sampler.epoch == 10
    assert tracking_sampler.set_epoch_call_count == 2

    # Wrapped sampler has no set_epoch at all: the wrapper must still update itself without raising.
    plain_sampler = _CustomSampler(data_source)
    wrapper = DistributedSamplerWrapper(plain_sampler, num_replicas=2, rank=0)

    wrapper.set_epoch(5)
    assert wrapper.epoch == 5

    # Wrapped sampler has a set_epoch attribute that is not callable: it must be neither called nor modified.
    string_attr_sampler = _CustomSampler(data_source, non_callable_set_epoch=True)
    wrapper = DistributedSamplerWrapper(string_attr_sampler, num_replicas=2, rank=0)

    wrapper.set_epoch(5)
    assert wrapper.epoch == 5
    assert string_attr_sampler.set_epoch == "not a method"

0 commit comments

Comments
 (0)