Skip to content

Commit 35e56ef

Browse files
authored
fix: AccumulateGrad stream mismatch warning when using DDP with Fabric & Trainer (#21746)
* update * update * update * update * update * update * update * update * Apply suggestion from @deependujha
1 parent 2849907 commit 35e56ef

5 files changed

Lines changed: 32 additions & 3 deletions

File tree

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2020

2121
### Fixed
2222

23-
-
23+
- Fixed `AccumulateGrad` stream mismatch warning when using DDP with Fabric ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746))
24+
2425

2526
---
2627

src/lightning/fabric/strategies/ddp.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,20 @@ def setup_environment(self) -> None:
124124
def setup_module(self, module: Module) -> DistributedDataParallel:
125125
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
126126
device_ids = self._determine_ddp_device_ids()
127+
ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext()
128+
127129
# https://pytorch.org/docs/stable/notes/cuda.html#id5
128-
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
130+
if device_ids is not None:
131+
capturing = torch.cuda.is_current_stream_capturing()
132+
if capturing:
133+
# DDP must be initialized on a side-stream for CUDA graph whole-network capture.
134+
# The resulting AccumulateGrad stream mismatch is intentional in this case.
135+
# See: https://pytorch.org/docs/stable/notes/cuda.html#id5
136+
ctx = torch.cuda.stream(torch.cuda.Stream())
137+
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
138+
else:
139+
# Default stream avoids AccumulateGrad stream mismatch warnings during normal training.
140+
ctx = torch.cuda.stream(torch.cuda.default_stream())
129141
with ctx:
130142
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
131143

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2828

2929
- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))
3030

31+
- Fixed `AccumulateGrad` stream mismatch warning when using DDP with Trainer ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746))
32+
3133
- Fixed `LightningModule.toggle_optimizer` / `untoggle_optimizer` breaking under `torch.compile` by disabling Dynamo tracing on these bookkeeping helpers ([#21513](https://github.com/Lightning-AI/pytorch-lightning/issues/21513))
3234

3335
---

src/lightning/pytorch/strategies/ddp.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,21 @@ def setup(self, trainer: "pl.Trainer") -> None:
189189
def _setup_model(self, model: Module) -> DistributedDataParallel:
190190
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
191191
device_ids = self.determine_ddp_device_ids()
192+
ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext()
193+
192194
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
193195
# https://pytorch.org/docs/stable/notes/cuda.html#id5
194-
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
196+
if device_ids is not None:
197+
capturing = torch.cuda.is_current_stream_capturing()
198+
if capturing:
199+
# DDP must be initialized on a side-stream for CUDA graph whole-network capture.
200+
# The resulting AccumulateGrad stream mismatch is intentional in this case.
201+
# See: https://pytorch.org/docs/stable/notes/cuda.html#id5
202+
ctx = torch.cuda.stream(torch.cuda.Stream())
203+
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
204+
else:
205+
# Default stream avoids AccumulateGrad stream mismatch warnings during normal training.
206+
ctx = torch.cuda.stream(torch.cuda.default_stream())
195207
with ctx:
196208
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
197209

tests/tests_fabric/strategies/test_ddp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def test_module_init_context(precision, expected_dtype):
146146
@mock.patch.dict(os.environ, {"LOCAL_RANK": "0"})
147147
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
148148
@mock.patch("torch.cuda.Stream")
149+
@mock.patch("torch.cuda.default_stream")
150+
@mock.patch("torch.cuda.is_current_stream_capturing", return_value=False)
149151
@mock.patch("torch.cuda.stream")
150152
def test_setup_with_cuda_stream(cuda_stream_mock, *_):
151153
model = torch.nn.Linear(2, 2)

0 commit comments

Comments
 (0)