|
15 | 15 | from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy |
16 | 16 | from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher |
17 | 17 | from lightning.fabric.utilities.distributed import ( |
| 18 | + DistributedSamplerWrapper, |
18 | 19 | _destroy_dist_connection, |
19 | 20 | _gather_all_tensors, |
20 | 21 | _get_default_process_group_backend_for_device, |
@@ -274,3 +275,73 @@ def test_is_dtensor(monkeypatch): |
274 | 275 |
|
275 | 276 | monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False) |
276 | 277 | assert not _is_dtensor(Mock(spec=DTensor)) |
| 278 | + |
| 279 | + |
| 280 | +class _CustomSampler(torch.utils.data.Sampler): |
| 281 | + """A custom sampler for testing DistributedSamplerWrapper.""" |
| 282 | + |
| 283 | + def __init__(self, data_source, non_callable_set_epoch: bool = False): |
| 284 | + self.data_source = data_source |
| 285 | + if non_callable_set_epoch: |
| 286 | + self.set_epoch = "not a method" # attribute exists but is not callable |
| 287 | + |
| 288 | + def __len__(self): |
| 289 | + return len(self.data_source) |
| 290 | + |
| 291 | + def __iter__(self): |
| 292 | + return iter(range(len(self.data_source))) |
| 293 | + |
| 294 | + |
| 295 | +class _CustomSamplerWithSetEpoch(_CustomSampler): |
| 296 | + """A custom sampler that tracks set_epoch calls for testing.""" |
| 297 | + |
| 298 | + def __init__(self, data_source): |
| 299 | + super().__init__(data_source) |
| 300 | + self.epoch = 0 |
| 301 | + self.set_epoch_call_count = 0 |
| 302 | + |
| 303 | + def set_epoch(self, epoch): |
| 304 | + self.epoch = epoch |
| 305 | + self.set_epoch_call_count += 1 |
| 306 | + |
| 307 | + |
| 308 | +def test_distributed_sampler_wrapper_set_epoch(): |
| 309 | + """Test that DistributedSamplerWrapper correctly handles set_epoch for various sampler types. |
| 310 | +
|
| 311 | + Reproduces issue #21454: When a sampler is wrapped by DistributedSamplerWrapper, calling set_epoch on the wrapper |
| 312 | + should forward the call to the underlying sampler if it supports the method. |
| 313 | +
|
| 314 | + """ |
| 315 | + data_source = list(range(100)) |
| 316 | + |
| 317 | + # Case 1: Sampler WITH set_epoch method - should forward the call |
| 318 | + sampler_with_set_epoch = _CustomSamplerWithSetEpoch(data_source) |
| 319 | + wrapper = DistributedSamplerWrapper(sampler_with_set_epoch, num_replicas=2, rank=0) |
| 320 | + |
| 321 | + assert sampler_with_set_epoch.epoch == 0 |
| 322 | + assert sampler_with_set_epoch.set_epoch_call_count == 0 |
| 323 | + |
| 324 | + wrapper.set_epoch(5) |
| 325 | + assert wrapper.epoch == 5 |
| 326 | + assert sampler_with_set_epoch.epoch == 5, "set_epoch was not forwarded to the underlying sampler" |
| 327 | + assert sampler_with_set_epoch.set_epoch_call_count == 1 |
| 328 | + |
| 329 | + wrapper.set_epoch(10) |
| 330 | + assert wrapper.epoch == 10 |
| 331 | + assert sampler_with_set_epoch.epoch == 10 |
| 332 | + assert sampler_with_set_epoch.set_epoch_call_count == 2 |
| 333 | + |
| 334 | + # Case 2: Sampler WITHOUT set_epoch method - should not fail |
| 335 | + sampler_without_set_epoch = _CustomSampler(data_source) |
| 336 | + wrapper = DistributedSamplerWrapper(sampler_without_set_epoch, num_replicas=2, rank=0) |
| 337 | + |
| 338 | + wrapper.set_epoch(5) # Should not raise |
| 339 | + assert wrapper.epoch == 5 |
| 340 | + |
| 341 | + # Case 3: Sampler with non-callable set_epoch attribute - should not fail or call it |
| 342 | + sampler_non_callable = _CustomSampler(data_source, non_callable_set_epoch=True) |
| 343 | + wrapper = DistributedSamplerWrapper(sampler_non_callable, num_replicas=2, rank=0) |
| 344 | + |
| 345 | + wrapper.set_epoch(5) # Should not raise |
| 346 | + assert wrapper.epoch == 5 |
| 347 | + assert sampler_non_callable.set_epoch == "not a method" # Should remain unchanged |
0 commit comments