#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
"""
LocalSGD
=========
-
This module implements a fault tolerant version of LocalSGD and related methods.
"""
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type

import torch
from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle

from torchft.manager import Manager

+logger: logging.Logger = logging.getLogger(__name__)
+

-class LocalSGD(nn.Module):
+class LocalSGD:
    """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
@@ -60,26 +64,22 @@ def __init__(
    ) -> None:
        """
        Args:
-            manager: The manager to use.
-            model: The model to wrap.
-            optimizer: The optimizer used by the model.
-            sync_every: How often to sync the model weights.
-            backup_device: The device to store the backup of the model parameters on. (default cpu)
-            pin_memory: Whether to pin the memory used for the backup of the model parameters.
+            manager (Manager): The manager to use.
+            model (nn.Module): The model to wrap.
+            optimizer (optim.Optimizer): The optimizer used by the model.
+            sync_every (int): How often to sync the model weights.
+            backup_device (Optional[torch.device]): The device to store the backup of the model parameters on. (default cpu)
+            pin_memory (bool): Whether to pin the memory used for the backup of the model parameters.
        """
        super().__init__()
-
        self._manager = manager
        self._model = model
+        self._local_optimizer = optimizer
        self._local_step = 0
-        self._started_step = False
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
        device = backup_device or torch.device("cpu")
-
        self._backup_parameters: Dict[str, torch.Tensor] = {}
-
        for name, p in self._model.named_parameters():
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
            if (
@@ -90,86 +90,88 @@ def __init__(
                t = t.pin_memory()
            self._backup_parameters[name] = t

+        self._hooks: List[RemovableHandle] = []
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        # Register a forward pre-hook to check for quorum
+        self._hooks.append(
+            self._model.register_forward_pre_hook(self._forward_step_pre_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions

    def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)

    def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.copy_(self._backup_parameters[name], non_blocking=False)

-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
-        Loads the state dict to the model and the backup parameters.
-
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+        This hook is registered on the optimizer and is called after the optimizer step.
        """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()

-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _forward_step_pre_hook(self, _module: nn.Module, _args: List[object]) -> None:
        """
-        Run the model parameters.
-
-        This should be called before the optimizer step.
-
-        This will start the quorum and save the parameters if this is the first step.
+        Start the quorum before each module forward.
        """
        if self._local_step == 0:
            self._manager.start_quorum()

-        self._started_step = True
-
-        return self._model.forward(*args, **kwargs)
-
-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
+    def sync(self) -> None:
        """
-        This hook is registered on the optimizer and is called after the optimizer step.
-
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
-
-        ``forward`` must be called before this function.
+        Synchronizes and averages the model weights across the manager.
        """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
+        self._perform_sync()

-        self._local_step += 1
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()

-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        self._local_step = 0

-            if self._manager.should_commit():
-                # save the parameters so we can restore from them later if necessary.
-                self._save_parameters()
-            else:
-                # commit failed, restore from the backup parameters
-                self._restore_parameters()
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
+        """
+        self._average()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?
@@ -182,3 +184,67 @@ def _average(self) -> None:

        for work in works:
            work.wait()
+
+
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between
+    the previous global weights and the current local weights).
+
+    DiLoCo: https://arxiv.org/pdf/2311.08105
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        if manager._use_async_quorum:
+            raise ValueError(
+                "Using DiLoCo requires synchronous quorum to be enabled. "
+                "Ensure that the manager is initialized with use_async_quorum=False"
+            )
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to calculate the pseudogradients, average them
+        across the manager group, and step using the outer optimizer.
+        """
+
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
+
+        self._average_grads()
+
+        # Restore the parameters back to the previous state
+        self._restore_parameters()
+
+        # Use the outer optimizer to update the model parameters
+        self._outer_optimizer.step()
+        self._outer_optimizer.zero_grad()
+
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the DiLoCo group.
+        """
+        works = []
+        for p in self._model.parameters():
+            # Perform allreduce on the pseudogradients
+            assert p.grad is not None
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
+        for work in works:
+            work.wait()
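
For context, here is a minimal usage sketch of the context-manager API introduced by this diff. It is not part of the change: the Manager construction is omitted, the module path is assumed to be torchft.local_sgd, and model, optimizer, criterion, and dataloader are placeholder names.

# Illustrative sketch only, not part of this diff. Assumes `manager` is an
# already-constructed torchft Manager; the training objects are placeholders.
import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


def train_local_sgd(manager: Manager, model: nn.Module, dataloader) -> None:
    optimizer = optim.SGD(model.parameters(), lr=1e-2)
    criterion = nn.CrossEntropyLoss()

    # Entering the context registers the forward pre-hook (which starts the
    # quorum on the first local step) and the optimizer step post-hook (which
    # counts local steps and calls sync() every `sync_every` steps).
    with LocalSGD(manager, model, optimizer, sync_every=32):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()  # the post-hook may trigger an allreduce + commit here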
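Similarly, a sketch of how the new DiLoCo subclass might be driven. The manager must be created with use_async_quorum=False or the constructor raises ValueError; the inner/outer optimizer choice and hyperparameters below loosely follow the DiLoCo paper and are illustrative, not values prescribed by this diff.

# Illustrative DiLoCo sketch only, not part of this diff.
from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager


def train_diloco(manager: Manager, model: nn.Module, dataloader) -> None:
    inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
    outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
    criterion = nn.CrossEntropyLoss()

    with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=500):
        for inputs, labels in dataloader:
            inner_optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            # Every `sync_every` inner steps, _perform_sync() turns the delta
            # between the current weights and the last backup into p.grad,
            # allreduces it across the group, restores the backup weights, and
            # applies the outer optimizer step.
            inner_optimizer.step()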