Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `AccumulateGrad` stream mismatch warning when using DDP with Fabric ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746))


---

Expand Down
14 changes: 13 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,20 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> DistributedDataParallel:
    """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.

    The wrapper is constructed inside a CUDA stream context whenever
    ``self._determine_ddp_device_ids()`` returns device ids; otherwise a
    ``nullcontext`` is used and no CUDA API is touched.

    Args:
        module: The module to wrap.

    Returns:
        The module wrapped in ``DistributedDataParallel``, built with ``self._ddp_kwargs``.

    """
    device_ids = self._determine_ddp_device_ids()
    ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext()

    if device_ids is not None:
        if torch.cuda.is_current_stream_capturing():
            # DDP must be initialized on a side-stream for CUDA graph whole-network capture.
            # The resulting AccumulateGrad stream mismatch is intentional in this case.
            # See: https://pytorch.org/docs/stable/notes/cuda.html#id5
            ctx = torch.cuda.stream(torch.cuda.Stream())
            # Look the toggle up defensively: PyTorch builds that do not expose it
            # would otherwise raise AttributeError here.
            # NOTE(review): the warning is switched off and never switched back on
            # by this method -- confirm that this is the intended scope.
            disable_warning = getattr(
                torch.autograd.graph, "set_warn_on_accumulate_grad_stream_mismatch", None
            )
            if disable_warning is not None:
                disable_warning(False)
        else:
            # Default stream avoids AccumulateGrad stream mismatch warnings during normal training.
            ctx = torch.cuda.stream(torch.cuda.default_stream())
    with ctx:
        return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))

- Fixed `AccumulateGrad` stream mismatch warning when using DDP with Trainer ([#21746](https://github.com/Lightning-AI/pytorch-lightning/pull/21746))

- Fixed `LightningModule.toggle_optimizer` / `untoggle_optimizer` breaking under `torch.compile` by disabling Dynamo tracing on these bookkeeping helpers ([#21513](https://github.com/Lightning-AI/pytorch-lightning/issues/21513))

---
Expand Down
14 changes: 13 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,21 @@ def setup(self, trainer: "pl.Trainer") -> None:
def _setup_model(self, model: Module) -> DistributedDataParallel:
    """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.

    The wrapper is constructed inside a CUDA stream context whenever
    ``self.determine_ddp_device_ids()`` returns device ids; otherwise a
    ``nullcontext`` is used and no CUDA API is touched.

    Args:
        model: The module to wrap.

    Returns:
        The module wrapped in ``DistributedDataParallel``, built with ``self._ddp_kwargs``.

    """
    device_ids = self.determine_ddp_device_ids()
    ctx: Union[torch.cuda.StreamContext, nullcontext] = nullcontext()

    log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
    if device_ids is not None:
        if torch.cuda.is_current_stream_capturing():
            # DDP must be initialized on a side-stream for CUDA graph whole-network capture.
            # The resulting AccumulateGrad stream mismatch is intentional in this case.
            # See: https://pytorch.org/docs/stable/notes/cuda.html#id5
            ctx = torch.cuda.stream(torch.cuda.Stream())
            # Look the toggle up defensively: PyTorch builds that do not expose it
            # would otherwise raise AttributeError here.
            # NOTE(review): the warning is switched off and never switched back on
            # by this method -- confirm that this is the intended scope.
            disable_warning = getattr(
                torch.autograd.graph, "set_warn_on_accumulate_grad_stream_mismatch", None
            )
            if disable_warning is not None:
                disable_warning(False)
        else:
            # Default stream avoids AccumulateGrad stream mismatch warnings during normal training.
            ctx = torch.cuda.stream(torch.cuda.default_stream())
    with ctx:
        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

Expand Down
2 changes: 2 additions & 0 deletions tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def test_module_init_context(precision, expected_dtype):
@mock.patch.dict(os.environ, {"LOCAL_RANK": "0"})
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
@mock.patch("torch.cuda.Stream")
@mock.patch("torch.cuda.default_stream")
@mock.patch("torch.cuda.is_current_stream_capturing", return_value=False)
@mock.patch("torch.cuda.stream")
def test_setup_with_cuda_stream(cuda_stream_mock, *_):
model = torch.nn.Linear(2, 2)
Expand Down
Loading