
Commit 932b7e3

c-pozzi and deependujha authored
docs: expand DDP metric synchronization guidance (#21685)
Restructure the "Synchronize validation and test logging" section in accelerator_prepare.rst into a problem-framing intro plus three subsections (sync_dist, TorchMetrics, manual all_gather), a decision table, and a common-pitfalls list.

Directly addresses the custom-metric case: accumulate per-step outputs, call all_gather at epoch end, and compute the metric. The "my compute runs N times" confusion is called out and resolved: after all_gather every rank holds the same data, so the redundant compute is cheap and correct; only self.log needs the rank_zero_only guard.

Refs #20117

Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent a8d32b0 commit 932b7e3

1 file changed

Lines changed: 128 additions & 22 deletions

docs/source-pytorch/accelerators/accelerator_prepare.rst

@@ -77,12 +77,27 @@ See :ref:`replace-sampler-ddp` for more information.
7777
Synchronize validation and test logging
7878
***************************************
7979

80-
When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
81-
This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. This will automatically average values across all processes.
82-
This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.
83-
The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.
80+
When running in distributed mode, each rank runs ``validation_step`` and ``test_step`` on its own
81+
shard of the data. Without explicit synchronization, the value your logger persists is rank 0's
82+
local value — computed on just ``1 / world_size`` of the validation or test set. That is the
83+
metric your :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and
84+
:class:`~lightning.pytorch.callbacks.EarlyStopping` callbacks see, so an unsynchronized metric
85+
can silently pick the wrong checkpoint.
8486

85-
Note if you use any built in metrics or custom metrics that use `TorchMetrics <https://torchmetrics.readthedocs.io/>`_, these do not need to be updated and are automatically handled for you.
87+
Lightning gives you three tools to fix this, and they are **not interchangeable**:
88+
89+
- ``sync_dist=True`` — mean-reduces a scalar across ranks. Correct only for averageable metrics.
90+
- `TorchMetrics <https://torchmetrics.readthedocs.io/>`__ — syncs the metric's internal *state*, then computes. Correct for non-averageable metrics such as F1 or AUC.
91+
- :meth:`~lightning.pytorch.core.LightningModule.all_gather` — gathers raw tensors across ranks so you can compute any reduction yourself.
92+
93+
Pick the lightest tool that fits the metric. If you accumulate per-step outputs and compute a
94+
custom metric in ``on_validation_epoch_end`` (or ``on_test_epoch_end``), jump to
95+
:ref:`manual-all-gather` — that is the pattern most DDP custom-metric questions come down to.
96+
97+
``sync_dist=True``
98+
==================
99+
100+
The simplest option. Lightning mean-reduces each logged value across all ranks before storing it.
86101

87102
.. testcode::
88103

@@ -101,31 +116,122 @@ Note if you use any built in metrics or custom metrics that use `TorchMetrics <h
101116
# Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
102117
self.log("test_loss", loss, on_step=True, on_epoch=True, sync_dist=True)
103118

104-
It is possible to perform some computation manually and log the reduced result on rank 0 as follows:
119+
The ``sync_dist`` option can also be used in logging calls during the training step, but be aware
120+
that this can lead to significant communication overhead and slow down your training.
121+
122+
.. warning::
123+
``sync_dist=True`` averages per-rank *values*. It is only correct when
124+
``mean(per_rank_metric) == global_metric``. It is **wrong** for F1, AUC, and precision or
125+
recall on imbalanced classes — the mean of per-rank F1 scores is not the global F1 score.
126+
For those metrics, reach for TorchMetrics instead.
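
The warning above can be checked with plain arithmetic. The following is a minimal sketch, not part of the committed docs: the per-rank confusion counts are made up for illustration, and ``f1`` is a hypothetical helper rather than a Lightning or TorchMetrics function. It compares the mean of per-rank F1 scores (what a mean-reduce would log) with the F1 computed from the summed counts (what syncing the metric state gives).

```python
def f1(tp, fp, fn):
    """F1 score from raw counts; returns 0.0 when the score is undefined."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Hypothetical (tp, fp, fn) counts observed on two ranks.
rank_counts = [(9, 1, 0), (1, 0, 9)]

# What a mean-reduce of the per-rank metric would produce.
per_rank_f1 = [f1(*counts) for counts in rank_counts]
mean_of_f1 = sum(per_rank_f1) / len(per_rank_f1)

# What merging the state first produces: sum the counts, then compute once.
merged_counts = [sum(column) for column in zip(*rank_counts)]
global_f1 = f1(*merged_counts)

print(f"mean of per-rank F1: {mean_of_f1:.4f}")
print(f"F1 from merged counts: {global_f1:.4f}")
```

With these made-up counts the two numbers differ by roughly 0.1, which is enough to change which checkpoint a monitor would select.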
127+
128+
TorchMetrics
129+
============
130+
131+
`TorchMetrics <https://torchmetrics.readthedocs.io/>`__ handles the non-averageable case by
132+
syncing the metric's internal *state* (for example, the running counts of true and false
133+
positives) across ranks, then computing the metric from the merged state. The result matches
134+
what you would get by evaluating on one rank with the full dataset. No ``sync_dist`` flag is
135+
needed; the metric synchronizes itself when it is logged.
105136

106137
.. code-block:: python
107138
108-
def __init__(self):
109-
super().__init__()
110-
self.outputs = []
139+
from torchmetrics.classification import BinaryF1Score
111140
112141
113-
def test_step(self, batch, batch_idx):
114-
x, y = batch
115-
tensors = self(x)
116-
self.outputs.append(tensors)
117-
return tensors
142+
class LitModel(LightningModule):
143+
def __init__(self):
144+
super().__init__()
145+
self.val_f1 = BinaryF1Score()
146+
147+
148+
def validation_step(self, batch, batch_idx):
149+
x, y = batch
150+
logits = self(x)
151+
self.val_f1.update(logits, y)
152+
# Passing the metric object to self.log triggers DDP sync at epoch end.
153+
self.log("val_f1", self.val_f1, on_epoch=True)
154+
155+
This is the recommended option for any classification, retrieval, or ranking metric.
156+
157+
.. _manual-all-gather:
118158

159+
Manual ``all_gather``
160+
=====================
119161

120-
def on_test_epoch_end(self):
121-
mean = torch.mean(self.all_gather(self.outputs))
122-
self.outputs.clear() # free memory
162+
Use this when your metric is a custom computation over outputs accumulated across the whole
163+
epoch — the case where neither ``sync_dist`` nor TorchMetrics fits. The pattern is: accumulate
164+
per-step outputs into a list on the module, then at epoch end call
165+
:meth:`~lightning.pytorch.core.LightningModule.all_gather` to combine each rank's contributions
166+
before computing the metric. Given a tensor, ``all_gather`` returns a tensor of shape
167+
``[world_size, *tensor_shape]``; given a list of tensors, it returns a list of such tensors. Every rank receives the same result.
123168

124-
# When you call `self.log` only on rank 0, don't forget to add
125-
# `rank_zero_only=True` to avoid deadlocks on synchronization.
126-
# Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/pytorch-lightning/issues/15852
127-
if self.trainer.is_global_zero:
128-
self.log("my_reduced_metric", mean, rank_zero_only=True)
169+
.. code-block:: python
170+
171+
class LitModel(LightningModule):
172+
def __init__(self):
173+
super().__init__()
174+
self.val_outputs = []
175+
176+
177+
def validation_step(self, batch, batch_idx):
178+
x, y = batch
179+
predictions = self(x)
180+
self.val_outputs.append(predictions)
181+
return predictions
182+
183+
184+
def on_validation_epoch_end(self):
185+
            # self.val_outputs is a list, so self.all_gather returns a list of tensors, each of shape [world_size, *tensor_shape], on every rank.
186+
gathered = self.all_gather(self.val_outputs)
187+
metric = my_custom_metric(gathered)
188+
self.val_outputs.clear() # free memory before the next epoch
189+
190+
# When you call `self.log` only on rank 0, don't forget to add
191+
# `rank_zero_only=True` to avoid deadlocks on synchronization.
192+
# Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/pytorch-lightning/issues/15852
193+
if self.trainer.is_global_zero:
194+
self.log("my_custom_val_metric", metric, rank_zero_only=True)
195+
196+
The same pattern applies to ``test_step`` / ``on_test_epoch_end``.
197+
198+
A common source of confusion here is that ``on_validation_epoch_end`` runs on every rank, so at
199+
first glance the metric looks like it is being computed ``world_size`` times. After
200+
``all_gather`` every rank already holds the *same* gathered tensor, so every rank computes the
201+
*same* value — the redundant work is cheap and the result is correct. The ``is_global_zero``
202+
guard belongs around ``self.log``, not around the computation. Never guard ``all_gather``
203+
itself with ``is_global_zero``: it is a collective, and if some ranks skip it the program will
204+
hang.
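
The reasoning above can be illustrated without a distributed setup. This is a sketch under stated assumptions: ``simulated_all_gather`` is a hypothetical stand-in for the real collective, the shard values are invented, and the median stands in for ``my_custom_metric`` as an example of a reduction that cannot be averaged across ranks.

```python
from statistics import median

# Hypothetical per-rank outputs accumulated over one validation epoch.
shards = {0: [1.0, 2.0, 3.0], 1: [4.0, 50.0, 60.0]}


def simulated_all_gather(shards):
    """Stand-in for a collective gather: every rank receives all shards in rank order."""
    gathered = [shards[rank] for rank in sorted(shards)]
    return {rank: gathered for rank in shards}


def my_custom_metric(gathered):
    """Example custom reduction: the median over every gathered value."""
    flat = [value for shard in gathered for value in shard]
    return median(flat)


# Each rank computes the metric on its own copy of the gathered data.
views = simulated_all_gather(shards)
metric_on_rank = {rank: my_custom_metric(view) for rank, view in views.items()}

# For contrast: averaging the per-rank medians gives a different number.
mean_of_local = sum(median(shard) for shard in shards.values()) / len(shards)

print(metric_on_rank)
print(mean_of_local)
```

Every rank arrives at the same value because every rank sees the same gathered data, so it does not matter which rank ends up logging it. Averaging the local medians would not reproduce that value.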
205+
206+
Which one should I use?
207+
=======================
208+
209+
.. list-table::
210+
:header-rows: 1
211+
:widths: 45 55
212+
213+
* - Metric
214+
- Use
215+
* - Averageable scalar (loss, accuracy, MSE)
216+
- ``sync_dist=True``
217+
* - Classification or ranking metric (F1, AUC, precision, recall)
218+
- TorchMetrics
219+
* - Custom reduction over gathered tensors
220+
- ``self.all_gather()``
221+
222+
Common pitfalls
223+
===============
224+
225+
- **Using** ``sync_dist=True`` **on a non-averageable metric.** The logged value is the mean of
226+
per-rank metrics, which is not the global metric. Use TorchMetrics instead.
227+
- **Guarding** ``all_gather`` **with** ``is_global_zero``. Collectives must be called on every
228+
rank. Put the guard around ``self.log``, not around the gather.
229+
- **Passing** ``rank_zero_only=True`` **to** ``self.log`` **without synchronizing first.** Rank 0
230+
logs its local value only, which is the ``1 / world_size`` problem this section opens with.
231+
232+
See also: the `TorchMetrics distributed evaluation guide
233+
<https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metrics-and-distributed-training-ddp>`_
234+
for how TorchMetrics synchronizes state internally.
129235

130236

131237
----
