Commit c12258d

Author: royfa
docs(fabric): add all-process note for save/load checkpoints
1 parent 8d86b24 commit c12258d

File tree: 3 files changed, +13 -0 lines changed

docs/source-fabric/api/fabric.rst

Lines changed: 5 additions & 0 deletions
@@ -16,3 +16,8 @@ Fabric
     :template: classtemplate.rst

     Fabric
+
+.. note::
+
+   In distributed training, :meth:`~lightning.fabric.fabric.Fabric.save` and
+   :meth:`~lightning.fabric.fabric.Fabric.load` must be called on all processes.
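To illustrate the new note, a minimal sketch of the intended usage, assuming a simple DDP run with a placeholder model, optimizer, and checkpoint path (none of these names come from the commit itself): every process makes the ``save`` and ``load`` calls, and neither call is guarded by a rank check.

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
    fabric.launch()

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    state = {"model": model, "optimizer": optimizer, "step": 0}

    # Collective call: every process runs it; the strategy decides which rank(s) write.
    fabric.save("path/to/checkpoint.ckpt", state)

    # Also collective; do not wrap it in `if fabric.global_rank == 0:`.
    fabric.load("path/to/checkpoint.ckpt", state)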

docs/source-fabric/guide/checkpoint/checkpoint.rst

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,8 @@ To save the state to the filesystem, pass it to the :meth:`~lightning.fabric.fab

     fabric.save("path/to/checkpoint.ckpt", state)

+This method must be called on all processes.
+
 This will unwrap your model and optimizer and automatically convert their ``state_dict`` for you.
 Fabric and the underlying strategy will decide in which format your checkpoint gets saved.
 For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` :doc:`saves multiple files from all ranks <distributed_checkpoint>`.
@@ -64,6 +66,8 @@ You can restore the state by loading a saved checkpoint back with :meth:`~lightn

     fabric.load("path/to/checkpoint.ckpt", state)

+This method must be called on all processes.
+
 Fabric will replace the state of your objects in-place.
 You can also request only to restore a portion of the checkpoint.
 For example, you want only to restore the model weights in your inference script:
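The hunk ends right before the guide's partial-restore example. As a rough sketch of that idea, assuming the checkpoint was written with a ``"model"`` key as in the save snippet above, and that ``Fabric.load`` returns whatever entries were not requested:

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
    fabric.launch()

    model = fabric.setup(torch.nn.Linear(32, 2))

    # Restore only the model weights; this is still called on every process.
    # Checkpoint entries that were not requested are returned for inspection.
    remainder = fabric.load("path/to/checkpoint.ckpt", {"model": model})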

docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,8 @@ The distributed checkpoint format is the default when you train with the :doc:`F
     # DON'T do this (inefficient):
     # torch.save("path/to/checkpoint/file", state)

+This method must be called on all processes.
+
 With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
 This reduces memory peaks and speeds up the saving to disk.
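As a sketch of the sharded save described above, assuming an FSDP setup with placeholder module sizes and paths (not taken verbatim from the guide):

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy

    fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(state_dict_type="sharded"))
    fabric.launch()

    model = fabric.setup_module(torch.nn.Linear(1024, 1024))
    optimizer = torch.optim.AdamW(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)

    state = {"model": model, "optimizer": optimizer}

    # Every rank makes this call and writes its own shard into the folder.
    fabric.save("path/to/checkpoint", state)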

@@ -138,6 +140,8 @@ You can easily load a distributed checkpoint in Fabric if your script uses :doc:
     # DON'T do this (inefficient):
     # model.load_state_dict(torch.load("path/to/checkpoint/file"))

+This method must be called on all processes.
+
 Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.

 .. collapse:: Full example
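And a matching sketch for the load side, under the same assumptions: the call is again collective on every process, and, as the hunk above notes, the device count does not have to match the run that wrote the checkpoint.

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy

    # For example, resume on 4 devices even if the checkpoint was saved from 2.
    fabric = Fabric(accelerator="cuda", devices=4, strategy=FSDPStrategy(state_dict_type="sharded"))
    fabric.launch()

    model = fabric.setup_module(torch.nn.Linear(1024, 1024))
    optimizer = torch.optim.AdamW(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)

    state = {"model": model, "optimizer": optimizer}

    # Called on every process, not only on rank 0.
    fabric.load("path/to/checkpoint", state)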
