This module implements a fault tolerant version of LocalSGD and related methods.
"""

- from typing import Any, Dict, List, Mapping, Optional
+ import logging
+ from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional

import torch
from torch import nn, optim
+ from torch.nn.parameter import Parameter
+ from torch.optim.optimizer import Optimizer

from torchft.manager import Manager

+ logger: logging.Logger = logging.getLogger(__name__)

- class LocalSGD(nn.Module):
+
+ class LocalSGD:
    """
-     LocalSGD is a model wrapper similar to DistributedDataParallel that
+     LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
@@ -71,8 +76,8 @@ def __init__(

        self._manager = manager
        self._model = model
+         self._local_optimizer = optimizer
        self._local_step = 0
-         self._started_step = False
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"

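With this change LocalSGD registers its hooks in __enter__ and removes them in __exit__, so it is driven as a context manager wrapped around the training loop rather than as a module wrapper. A minimal usage sketch, assuming manager, model, optimizer, criterion, and trainloader are placeholders constructed elsewhere (they are not part of this diff):

    # Hypothetical usage; every name below except LocalSGD is assumed to exist already.
    with LocalSGD(manager, model, optimizer, sync_every=100):
        for batch, labels in trainloader:
            optimizer.zero_grad()
            out = model(batch)             # forward pre-hook starts the quorum on the first local step
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()               # post-hook counts steps and calls sync() every sync_every steps
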
@@ -93,7 +98,30 @@ def __init__(
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

-         optimizer.register_step_post_hook(self._step_post_hook)
+     def __enter__(self):
+         # Add optimizer hook which increments the local step counter and syncs if necessary
+         self._opt_hook = self._local_optimizer.register_step_post_hook(
+             self._step_post_hook
+         )
+
+         # Register a forward prehook to check for quorum
+         self._forward_pre_hook = self._model.register_forward_pre_hook(
+             self._forward_step_pre_hook
+         )
+
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         # Handle any cleanup or error handling here
+         if exc_type is not None:
+             # If an exception occurred, restore parameters
+             self._restore_parameters()
+
+         # Clean up hooks
+         self._opt_hook.remove()
+         self._forward_pre_hook.remove()
+
+         return False  # Propagate exceptions

    def _save_parameters(self) -> None:
        # TODO: consider running copy on a separate stream
@@ -105,71 +133,53 @@ def _restore_parameters(self) -> None:
        for name, p in self._model.named_parameters():
            p.data.copy_(self._backup_parameters[name], non_blocking=True)

-     # pyre-fixme[14]: support state_dict args
-     def state_dict(self) -> Dict[str, object]:
-         """
-         state_dict returns the state_dict from the last time LocalSGD
-         synchronized and not the current weights.
-         """
-         state_dict = self._model.state_dict()
-         for name, p in self._backup_parameters.items():
-             assert name in state_dict
-             state_dict[name] = p
-         return state_dict
-
-     def load_state_dict(
-         self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+     def _step_post_hook(
+         self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
-         Loads the state dict to the model and the backup parameters.
-
-         This must be called while the model weights aren't being modified to
-         avoid corrupting the backup weights.
+         This hook is registered on the optimizer and is called after the optimizer step.
        """
-         self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-         self._save_parameters()
+         self._local_step += 1
+         if self._local_step >= self._sync_every:
+             self.sync()

-     def forward(self, *args: object, **kwargs: object) -> object:
+     def _forward_step_pre_hook(self, _module, _args):
        """
-         Run the model parameters.
-
-         This should be called before the optimizer step.
-
-         This will start the quorum and save the parameters if this is the first step.
+         Start the quorum before each module forward.
        """
        if self._local_step == 0:
            self._manager.start_quorum()

-         self._started_step = True
+     # def should_sync(self) -> bool:
+     #     """
+     #     Checks if the model should be synchronized.
+     #     """
+     #     if self._local_step >= self._sync_every:
+     #         return True
+     #     else:
+     #         return False

-         return self._model.forward(*args, **kwargs)
-
-     def _step_post_hook(
-         self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-     ) -> None:
+     def sync(self) -> None:
        """
-         This hook is registered on the optimizer and is called after the optimizer step.
-
-         This will call the allreduce on the model weights every sync_every steps.
-         If any errors occur it will restore to the weights from the previous sync.
-
-         ``forward`` must be called before this function.
+         Synchronizes and averages the model weights across the manager.
        """
-         assert self._started_step, "forward must be called before step"
-         self._started_step = False
+         self._local_step = 0
+         self._perform_sync()

-         self._local_step += 1
+         if self._manager.should_commit():
+             # save the parameters so we can restore from them later if necessary.
+             self._save_parameters()
+         else:
+             # commit failed, restore from the backup parameters
+             self._restore_parameters()

-         if self._local_step >= self._sync_every:
-             self._local_step = 0
-             self._average()
-
-             if self._manager.should_commit():
-                 # save the parameters so we can restore from them later if necessary.
-                 self._save_parameters()
-             else:
-                 # commit failed, restore from the backup parameters
-                 self._restore_parameters()
+     def _perform_sync(self) -> None:
+         """
+         Performs the synchronization of the model weights across the manager.
+         This method is intended to be overridden by subclasses to implement custom
+         synchronization logic.
+         """
+         self._average()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?
@@ -182,3 +192,63 @@ def _average(self) -> None:

        for work in works:
            work.wait()
+
+
+ class DiLoCo(LocalSGD):
+     """
+     DiLoCo is a subclass of LocalSGD that overrides the synchronization
+     mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
+
+     diloco: https://arxiv.org/pdf/2311.08105
+     """
+
+     def __init__(
+         self,
+         manager: Manager,
+         model: nn.Module,
+         inner_optimizer: optim.Optimizer,
+         outer_optimizer: optim.Optimizer,
+         sync_every: int,
+         backup_device: Optional[torch.device] = None,
+         pin_memory: bool = True,
+     ) -> None:
+         super().__init__(
+             manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+         )
+         self._outer_optimizer = outer_optimizer
+
+     def _model_sync(self) -> None:
+         """
+         ensure model has the same weights
+         """
+         pass
+
+     def _perform_sync(self) -> None:
+         """
+         Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
+         step using the outer optimizer.
+         """
+
+         # Set the .grad field of each parameter to its pseudogradient
+         for name, p in self._model.named_parameters():
+             assert name in self._backup_parameters
+             pseudogradient = p.data - self._backup_parameters[name]
+             p.grad = pseudogradient
+
+         self._average_grads()
+
+         # Use the outer optimizer to update the model parameters
+         self._outer_optimizer.step()
+
+     def _average_grads(self) -> None:
+         """
+         Average the gradients across the diloco group.
+         """
+         works = []
+         for p in self._model.parameters():
+             # Perform allreduce on the pseudogradients
+             work = self._manager.allreduce(p.grad)
+             works.append(work)
+         # Wait for all allreduce operations to complete
+         for work in works:
+             work.wait()
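
DiLoCo follows the same context-manager pattern but takes both an inner and an outer optimizer: the inner optimizer drives the local steps, and on sync the averaged pseudogradients are applied through the outer optimizer. A minimal usage sketch, assuming manager, model, criterion, and trainloader exist elsewhere and using purely illustrative hyperparameters (none of these values come from this diff):

    # Hypothetical usage; optimizer choices and settings are placeholders.
    inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
    outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

    with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=500):
        for batch, labels in trainloader:
            inner_optimizer.zero_grad()
            loss = criterion(model(batch), labels)
            loss.backward()
            inner_optimizer.step()  # after sync_every steps, _perform_sync() sets p.grad to the
                                    # pseudogradient, allreduces it, and steps outer_optimizer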