Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- fixed AccumulateGrad stream mismatch warning when using DDP with Fabric ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746))


---

Expand Down
14 changes: 13 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,20 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext()

# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
if device_ids is not None:
capturing = torch.cuda.is_current_stream_capturing()
if capturing:
# DDP must be initialized on a side-stream for CUDA graph whole-network capture.
# The resulting AccumulateGrad stream mismatch is intentional in this case.
# See: https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream())
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
else:
# Default stream avoids AccumulateGrad stream mismatch warnings during normal training.
ctx = torch.cuda.stream(torch.cuda.default_stream())
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))

- fixed AccumulateGrad stream mismatch warning when using DDP with Trainer ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746))

---

## [2.6.4] - 2026-05-20
Expand Down
14 changes: 13 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,21 @@ def setup(self, trainer: "pl.Trainer") -> None:
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self.determine_ddp_device_ids()
ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext()

log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
if device_ids is not None:
capturing = torch.cuda.is_current_stream_capturing()
if capturing:
# DDP must be initialized on a side-stream for CUDA graph whole-network capture.
# The resulting AccumulateGrad stream mismatch is intentional in this case.
# See: https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream())
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False)
else:
# Default stream avoids AccumulateGrad stream mismatch warnings during normal training.
ctx = torch.cuda.stream(torch.cuda.default_stream())
with ctx:
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
2 changes: 2 additions & 0 deletions tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def test_module_init_context(precision, expected_dtype):
@mock.patch.dict(os.environ, {"LOCAL_RANK": "0"})
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
@mock.patch("torch.cuda.Stream")
@mock.patch("torch.cuda.default_stream")
@mock.patch("torch.cuda.is_current_stream_capturing", return_value=False)
@mock.patch("torch.cuda.stream")
def test_setup_with_cuda_stream(cuda_stream_mock, *_):
model = torch.nn.Linear(2, 2)
Expand Down
Loading