Commit 6d6e9a4

local_sgd: initial version of fault tolerant LocalSGD (#47)
1 parent a484e4f commit 6d6e9a4

8 files changed: +323 -39 lines changed

docs/source/index.rst (+1)

@@ -17,6 +17,7 @@ the entire training job.
    manager
    optim
    ddp
+   local_sgd
    data
    checkpointing
    parameter_server

docs/source/local_sgd.rst (new file, +4)

.. automodule:: torchft.local_sgd
   :members:
   :undoc-members:
   :show-inheritance:

torchft/ddp.py (+2 -2)

@@ -68,7 +68,7 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
     def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
-        return state.allreduce_grad(bucket.buffer())
+        return state.allreduce(bucket.buffer())


 class PureDistributedDataParallel(nn.Module):
@@ -88,7 +88,7 @@ def __init__(self, manager: "Manager", module: nn.Module) -> None:

         def post_grad_hook(p: torch.Tensor) -> None:
             if p.grad is not None:
-                manager.allreduce_grad(p.grad)
+                manager.allreduce(p.grad)

         for p in module.parameters():
             p.register_post_accumulate_grad_hook(post_grad_hook)

torchft/ddp_test.py (+3 -3)

@@ -32,14 +32,14 @@ def test_pure_ddp(self) -> None:
         for p in m.parameters():
             self.assertIsNotNone(p.grad)

-        self.assertEqual(manager.allreduce_grad.call_count, len(list(m.parameters())))
+        self.assertEqual(manager.allreduce.call_count, len(list(m.parameters())))

     def test_ddp(self) -> None:
         manager = create_autospec(Manager)

         call_count = 0

-        def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
             nonlocal call_count

             call_count += 1
@@ -48,7 +48,7 @@ def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]:
             fut.set_result(tensor)
             return fut

-        manager.allreduce_grad = allreduce_grad
+        manager.allreduce = allreduce

         m = nn.Linear(3, 4)
         m = DistributedDataParallel(manager, m)

torchft/local_sgd.py (new file, +184)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
LocalSGD
=========

This module implements a fault tolerant version of LocalSGD and related methods.
"""

from typing import Any, Dict, List, Mapping, Optional

import torch
from torch import nn, optim

from torchft.manager import Manager


class LocalSGD(nn.Module):
    """
    LocalSGD is a model wrapper similar to DistributedDataParallel that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
    way using a torchft Manager. The allreduce on the parameters will happen
    every ``sync_every`` steps after the optimizer.step call.

    To make this safe and fault tolerant, a backup copy of the weights is
    required. By default these are stored in CPU memory. If any error occurs
    during the LocalSGD step, the step will be discarded and the model
    parameters will reset back to the last time LocalSGD synchronized.

    The backup weights could be eliminated by relaxing the guarantee of exactly
    ``sync_every`` steps but that would diverge from the LocalSGD algorithm.
    DiLoCo also needs this backup copy to compute the delta.

    The torchft quorum is computed at the beginning of ``sync_every`` steps. If
    any error occurs, or a worker fails between syncs, ``sync_every`` steps will be
    discarded and a new quorum will be computed on the next step.

    If running in async mode, on a joining worker the first ``sync_every`` steps
    will be discarded as the model will be recovering during that period. When
    using sync mode, the checkpoint will be restored prior to the first step.

    TODO: add a way via Manager to detect workers failing early for shrink only
    TODO: add DiLoCo support
    """

    def __init__(
        self,
        manager: Manager,
        model: nn.Module,
        optimizer: optim.Optimizer,
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
    ) -> None:
        """
        Args:
            manager: The manager to use.
            model: The model to wrap.
            optimizer: The optimizer used by the model.
            sync_every: How often to sync the model weights.
            backup_device: The device to store the backup of the model parameters on. (default cpu)
            pin_memory: Whether to pin the memory used for the backup of the model parameters.
        """
        super().__init__()

        self._manager = manager
        self._model = model
        self._local_step = 0
        self._started_step = False
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"

        device = backup_device or torch.device("cpu")

        self._backup_parameters: Dict[str, torch.Tensor] = {}

        for name, p in self._model.named_parameters():
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
            if (
                pin_memory
                and t.device == torch.device("cpu")
                and torch.cuda.is_available()
            ):
                t = t.pin_memory()
            self._backup_parameters[name] = t

        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

        optimizer.register_step_post_hook(self._step_post_hook)

    def _save_parameters(self) -> None:
        # TODO: consider running copy on a separate stream
        for name, p in self._model.named_parameters():
            self._backup_parameters[name].copy_(p.data, non_blocking=True)

    def _restore_parameters(self) -> None:
        # TODO: consider running copy on a separate stream
        for name, p in self._model.named_parameters():
            p.data.copy_(self._backup_parameters[name], non_blocking=True)

    # pyre-fixme[14]: support state_dict args
    def state_dict(self) -> Dict[str, object]:
        """
        state_dict returns the state_dict from the last time LocalSGD
        synchronized and not the current weights.
        """
        state_dict = self._model.state_dict()
        for name, p in self._backup_parameters.items():
            assert name in state_dict
            state_dict[name] = p
        return state_dict

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
    ) -> None:
        """
        Loads the state dict to the model and the backup parameters.

        This must be called while the model weights aren't being modified to
        avoid corrupting the backup weights.
        """
        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
        self._save_parameters()

    def forward(self, *args: object, **kwargs: object) -> object:
        """
        Run the model forward pass.

        This should be called before the optimizer step.

        This will start the quorum and save the parameters if this is the first step.
        """
        if self._local_step == 0:
            self._manager.start_quorum()

        self._started_step = True

        return self._model.forward(*args, **kwargs)

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.

        This will call the allreduce on the model weights every ``sync_every`` steps.
        If any errors occur it will restore to the weights from the previous sync.

        ``forward`` must be called before this function.
        """
        assert self._started_step, "forward must be called before step"
        self._started_step = False

        self._local_step += 1

        if self._local_step >= self._sync_every:
            self._local_step = 0
            self._average()

            if self._manager.should_commit():
                # save the parameters so we can restore from them later if necessary.
                self._save_parameters()
            else:
                # commit failed, restore from the backup parameters
                self._restore_parameters()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?

        works = []

        for p in self._model.parameters():
            # TODO: bucketize parameters
            works.append(self._manager.allreduce(p))

        for work in works:
            work.wait()
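
For context, a minimal usage sketch of the wrapper added above. It assumes a fully configured torchft `Manager` named `manager` has already been constructed elsewhere (its construction is outside this commit), and `data_loader` and the model/optimizer choices are illustrative only:

# Illustrative sketch, not part of this commit. `manager` is assumed to be a
# configured torchft Manager; `data_loader` is a placeholder for real data.
import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD

model = nn.Linear(3, 4)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Wrap the model: parameters are allreduced every 100 optimizer steps and a
# CPU backup copy is kept so a failed sync can be rolled back.
local_sgd = LocalSGD(manager, model, optimizer, sync_every=100)

for batch, target in data_loader:
    optimizer.zero_grad()
    out = local_sgd(batch)  # forward() starts the quorum on the first local step
    loss = nn.functional.mse_loss(out, target)
    loss.backward()
    optimizer.step()  # the registered post-step hook handles the periodic sync

The wrapper leaves the training loop unchanged: all synchronization is driven by forward() and the optimizer step post-hook shown in the diff above.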

torchft/local_sgd_test.py (new file, +96)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict
from unittest import TestCase
from unittest.mock import create_autospec

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, 5),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


def _params_dict(m: torch.nn.Module) -> Dict[str, torch.Tensor]:
    return {name: p.data for name, p in m.named_parameters()}


def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    return {name: value.clone().detach() for name, value in state_dict.items()}


class LocalSGDTest(TestCase):
    def test_local_sgd_healthy(self) -> None:
        base_m = SimpleModel()
        optimizer = optim.SGD(base_m.parameters())
        manager = create_autospec(Manager)

        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
        self.assertEqual(m._local_step, 0)

        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))

        inp = torch.rand(2, 3)

        loss = m(inp).mean()
        loss.backward()
        optimizer.step()

        self.assertEqual(m._local_step, 1)
        self.assertEqual(manager.start_quorum.call_count, 1)

        loss = m(inp).mean()
        loss.backward()
        optimizer.step()

        manager.should_commit.return_value = True
        self.assertEqual(m._local_step, 0)

        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
        self.assertEqual(manager.should_commit.call_count, 1)
        self.assertEqual(manager.allreduce.call_count, 4)

    def test_local_sgd_recovery(self) -> None:
        base_m = SimpleModel()
        optimizer = optim.SGD(base_m.parameters())
        manager = create_autospec(Manager)

        m = LocalSGD(manager, base_m, optimizer, sync_every=2)

        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
        og_state_dict = _copy_state_dict(base_m.state_dict())

        inp = torch.rand(2, 3)

        loss = m(inp).mean()
        loss.backward()
        optimizer.step()

        self.assertEqual(m._local_step, 1)

        state_dict = m.state_dict()
        torch.testing.assert_close(state_dict, m._backup_parameters)
        torch.testing.assert_close(state_dict, og_state_dict)

        m.load_state_dict(state_dict)
        torch.testing.assert_close(_params_dict(base_m), state_dict)
        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
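
The tests above cover the healthy sync path and state_dict round-tripping. A hedged sketch of how the failed-commit rollback path could additionally be exercised with the same mocking style (illustrative only, not part of this commit; it would live inside LocalSGDTest):

    # Illustrative sketch: when Manager.should_commit() returns False, LocalSGD
    # restores the backup parameters, undoing the local optimizer steps.
    def test_local_sgd_commit_failure_sketch(self) -> None:
        base_m = SimpleModel()
        optimizer = optim.SGD(base_m.parameters())
        manager = create_autospec(Manager)
        manager.should_commit.return_value = False  # simulate a failed commit

        m = LocalSGD(manager, base_m, optimizer, sync_every=1)
        backup_before = _copy_state_dict(m._backup_parameters)

        inp = torch.rand(2, 3)
        m(inp).mean().backward()
        optimizer.step()  # sync_every=1: the hook averages, commit fails, backup restored

        # The live parameters should match the pre-step backup after the rollback.
        torch.testing.assert_close(_params_dict(base_m), backup_before)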
