This module implements a fault tolerant version of LocalSGD and related methods.
"""

-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional

import torch
from torch import nn, optim

+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+
from torchft.manager import Manager

+logger: logging.Logger = logging.getLogger(__name__)
+

-class LocalSGD(nn.Module):
+class LocalSGD:
    """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
@@ -71,8 +77,8 @@ def __init__(

        self._manager = manager
        self._model = model
+        self._local_optimizer = optimizer
        self._local_step = 0
-        self._started_step = False
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"

@@ -93,7 +99,30 @@ def __init__(
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self):
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._opt_hook = self._local_optimizer.register_step_post_hook(
+            self._step_post_hook
+        )
+
+        # Register a forward pre-hook to check for quorum
+        self._forward_pre_hook = self._model.register_forward_pre_hook(
+            self._forward_step_pre_hook
+        )
+
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+
+        # Clean up hooks
+        self._opt_hook.remove()
+        self._forward_pre_hook.remove()
+
+        return False  # Propagate exceptions

    def _save_parameters(self) -> None:
        # TODO: consider running copy on a separate stream
@@ -105,71 +134,53 @@ def _restore_parameters(self) -> None:
        for name, p in self._model.named_parameters():
            p.data.copy_(self._backup_parameters[name], non_blocking=True)

-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
-        Loads the state dict to the model and the backup parameters.
-
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+        This hook is registered on the optimizer and is called after the optimizer step.
        """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()

-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _forward_step_pre_hook(self, _module, _args):
        """
-        Run the model parameters.
-
-        This should be called before the optimizer step.
-
-        This will start the quorum and save the parameters if this is the first step.
+        Start the quorum before each module forward.
        """
        if self._local_step == 0:
            self._manager.start_quorum()

-        self._started_step = True
-
-        return self._model.forward(*args, **kwargs)
+    # def should_sync(self) -> bool:
+    #     """
+    #     Checks if the model should be synchronized.
+    #     """
+    #     if self._local_step >= self._sync_every:
+    #         return True
+    #     else:
+    #         return False

-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
+    def sync(self) -> None:
        """
-        This hook is registered on the optimizer and is called after the optimizer step.
-
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
-
-        ``forward`` must be called before this function.
+        Synchronizes and averages the model weights across the manager.
        """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
-
-        self._local_step += 1
+        self._local_step = 0
+        self._perform_sync()

-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        if self._manager.should_commit():
+            # save the parameters so we can restore from them later if necessary.
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()

-        if self._manager.should_commit():
-            # save the parameters so we can restore from them later if necessary.
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
+        """
+        self._average()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?
@@ -182,3 +193,63 @@ def _average(self) -> None:

        for work in works:
            work.wait()
+
+
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between
+    the previous global weights and the current local weights).
+
+    DiLoCo: https://arxiv.org/pdf/2311.08105
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _model_sync(self) -> None:
+        """
+        Ensure the model has the same weights.
+        """
+        pass
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to compute the pseudogradients, average them
+        across the manager group, and step using the outer optimizer.
+        """
+
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
+
+        self._average_grads()
+
+        # Use the outer optimizer to update the model parameters
+        self._outer_optimizer.step()
+
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
+        works = []
+        for p in self._model.parameters():
+            # Perform allreduce on the pseudogradients
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
+        for work in works:
+            work.wait()
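
Usage sketch (not part of the diff above): the snippet below shows how the new context-manager API is meant to be driven from a training loop, with __enter__ installing the forward pre-hook and optimizer post-hook and __exit__ removing them. Only the LocalSGD constructor arguments and the hook behavior come from the diff; the torchft.local_sgd import path, the Manager setup, and the model/optimizer/data names are assumptions for illustration.

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD  # assumed import path
from torchft.manager import Manager


def train_loop(manager: Manager, model: nn.Module, trainloader) -> None:
    # Hypothetical optimizer and loss; any torch.optim.Optimizer works here.
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Entering the context registers the forward pre-hook (starts the quorum on
    # the first step after a sync) and the optimizer post-hook (counts local
    # steps and calls sync() every `sync_every` steps); exiting removes both
    # hooks and restores the backup parameters if an exception escaped.
    with LocalSGD(manager, model, optimizer, sync_every=100):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()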
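
For DiLoCo, a similar hedged sketch with an inner/outer optimizer pair, reusing the hypothetical manager, model, criterion, and trainloader names from the sketch above; the optimizer choices and hyperparameters are illustrative, and only the constructor signature comes from the diff.

from torchft.local_sgd import DiLoCo  # assumed import path

# Inner optimizer drives local steps; outer optimizer applies the averaged
# pseudogradients (choices below are assumptions, not prescribed by the diff).
inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=500):
    for inputs, labels in trainloader:
        inner_optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        inner_optimizer.step()
        # Every 500 inner steps, _perform_sync() sets each p.grad to the
        # pseudogradient (current local weights minus the last synced weights),
        # allreduces the pseudogradients across the group, and calls
        # outer_optimizer.step().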