Restructure the "Synchronize validation and test logging" section in
accelerator_prepare.rst into a problem-framing intro plus three
subsections (sync_dist, TorchMetrics, manual all_gather), a decision
table, and a common-pitfalls list.
Directly addresses the custom-metric case: accumulate per-step outputs,
call all_gather at epoch end, and compute the metric. The "my compute
runs N times" confusion is called out and resolved — after all_gather
every rank holds the same data, so the redundant compute is cheap and
correct; only self.log needs the rank_zero_only guard.
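
A minimal sketch of the reasoning above, assuming two ranks and a binary F1 metric. It is a pure-Python simulation and makes no Lightning or torch calls: `f1_score` and `simulated_all_gather` are hypothetical helpers standing in for the custom metric and for the collective gather, and the shard data is made up for illustration.

```python
# Pure-Python simulation of the pattern; nothing here is Lightning API.
# `f1_score` and `simulated_all_gather` are hypothetical helpers.


def f1_score(preds, targets):
    """Binary F1 from two equal-length lists of 0/1 labels."""
    pairs = list(zip(preds, targets))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def simulated_all_gather(shards):
    """Stand-in for a collective gather: every rank gets all shards, concatenated."""
    flat = [x for shard in shards for x in shard]
    return [list(flat) for _ in shards]


# Illustrative data: each inner list is what one rank accumulated over its steps.
preds_by_rank = [[1, 1, 1, 0], [0, 0, 0, 1]]
targets_by_rank = [[1, 1, 1, 1], [0, 0, 1, 0]]

# Mean-reducing a per-rank F1 (what averaging a logged scalar amounts to).
local_f1 = [f1_score(p, t) for p, t in zip(preds_by_rank, targets_by_rank)]
mean_of_local_f1 = sum(local_f1) / len(local_f1)

# Gather first, then let every rank compute the metric on the full data.
gathered_preds = simulated_all_gather(preds_by_rank)
gathered_targets = simulated_all_gather(targets_by_rank)
f1_after_gather = [
    f1_score(p, t) for p, t in zip(gathered_preds, gathered_targets)
]

for rank, value in enumerate(f1_after_gather):
    if rank == 0:  # stands in for the rank-zero guard around the logging call
        print(f"mean of local F1: {mean_of_local_f1:.4f}, gathered F1: {value:.4f}")
```

In this toy setup every simulated rank computes the same F1 after the gather, while the mean of the per-rank F1 values differs from it. That is why the repeated compute is harmless and why only the logging call needs a guard.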
Refs #20117
Co-authored-by: Deependu <deependujha21@gmail.com>
@@ -77,12 +77,27 @@ See :ref:`replace-sampler-ddp` for more information.
 Synchronize validation and test logging
 ***************************************
 
-When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
-This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. This will automatically average values across all processes.
-This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.
-The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.
+When running in distributed mode, each rank runs ``validation_step`` and ``test_step`` on its own
+shard of the data. Without explicit synchronization, the value your logger persists is rank 0's
+local value — computed on just ``1 / world_size`` of the validation or test set. That is the
+metric your :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and
+:class:`~lightning.pytorch.callbacks.EarlyStopping` callbacks see, so an unsynchronized metric
+can silently pick the wrong checkpoint.
 
-Note if you use any built in metrics or custom metrics that use `TorchMetrics <https://torchmetrics.readthedocs.io/>`_, these do not need to be updated and are automatically handled for you.
+Lightning gives you three tools to fix this, and they are **not interchangeable**:
+
+- ``sync_dist=True`` — mean-reduces a scalar across ranks. Correct only for averageable metrics.
+- `TorchMetrics <https://torchmetrics.readthedocs.io/>`__ — syncs the metric's internal *state*, then computes. Correct for non-averageable metrics such as F1 or AUC.
+- :meth:`~lightning.pytorch.core.LightningModule.all_gather` — gathers raw tensors across ranks so you can compute any reduction yourself.
+
+Pick the lightest tool that fits the metric. If you accumulate per-step outputs and compute a
+custom metric in ``on_validation_epoch_end`` (or ``on_test_epoch_end``), jump to
+:ref:`manual-all-gather` — that is the pattern most DDP custom-metric questions come down to.
+
+``sync_dist=True``
+==================
+
+The simplest option. Lightning mean-reduces each logged value across all ranks before storing it.
 
 .. testcode::
 
@@ -101,31 +116,122 @@ Note if you use any built in metrics or custom metrics that use `TorchMetrics <h
 # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)