diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 324cf25cd9527..8267dae81e42e 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -20,7 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- fixed AccumulateGrad stream mismatch warning when using DDP with Fabric ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746)) + --- diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index af182ad7f422f..8c3d7b935d472 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -124,8 +124,20 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() + ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext() + # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if device_ids is not None: + capturing = torch.cuda.is_current_stream_capturing() + if capturing: + # DDP must be initialized on a side-stream for CUDA graph whole-network capture. + # The resulting AccumulateGrad stream mismatch is intentional in this case. + # See: https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) + torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) + else: + # Default stream avoids AccumulateGrad stream mismatch warnings during normal training. + ctx = torch.cuda.stream(torch.cuda.default_stream()) with ctx: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index f8e1134f1d101..5168e1a261dc8 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623)) +- fixed AccumulateGrad stream mismatch warning when using DDP with Trainer ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746)) + - Fixed `LightningModule.toggle_optimizer` / `untoggle_optimizer` breaking under `torch.compile` by disabling Dynamo tracing on these bookkeeping helpers ([#21513](https://github.com/Lightning-AI/pytorch-lightning/issues/21513)) --- diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 4eca6159ddced..8c75eae036865 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -189,9 +189,21 @@ def setup(self, trainer: "pl.Trainer") -> None: def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() + ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext() + log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if device_ids is not None: + capturing = torch.cuda.is_current_stream_capturing() + if capturing: + # DDP must be initialized on a side-stream for CUDA graph whole-network capture. + # The resulting AccumulateGrad stream mismatch is intentional in this case. + # See: https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) + torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) + else: + # Default stream avoids AccumulateGrad stream mismatch warnings during normal training. + ctx = torch.cuda.stream(torch.cuda.default_stream()) with ctx: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index f302da5d1bc4f..455dd6bf2934c 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -146,6 +146,8 @@ def test_module_init_context(precision, expected_dtype): @mock.patch.dict(os.environ, {"LOCAL_RANK": "0"}) @mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel") @mock.patch("torch.cuda.Stream") +@mock.patch("torch.cuda.default_stream") +@mock.patch("torch.cuda.is_current_stream_capturing", return_value=False) @mock.patch("torch.cuda.stream") def test_setup_with_cuda_stream(cuda_stream_mock, *_): model = torch.nn.Linear(2, 2)