@@ -3,25 +3,29 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
"""
LocalSGD
=========
-
This module implements a fault tolerant version of LocalSGD and related methods.
"""
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type

import torch
from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle

from torchft.manager import Manager

+logger: logging.Logger = logging.getLogger(__name__)
+

-class LocalSGD(nn.Module):
+class LocalSGD:
    """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
@@ -60,26 +64,22 @@ def __init__(
    ) -> None:
        """
        Args:
-            manager: The manager to use.
-            model: The model to wrap.
-            optimizer: The optimizer used by the model.
-            sync_every: How often to sync the model weights.
-            backup_device: The device to store the backup of the model parameters on. (default cpu)
-            pin_memory: Whether to pin the memory used for the backup of the model parameters.
+            manager (Manager): The manager to use.
+            model (nn.Module): The model to wrap.
+            optimizer (optim.Optimizer): The optimizer used by the model.
+            sync_every (int): How often to sync the model weights.
+            backup_device (Optional[torch.device]): The device to store the backup of the model parameters on. (default cpu)
+            pin_memory (bool): Whether to pin the memory used for the backup of the model parameters.
        """
        super().__init__()
-
        self._manager = manager
        self._model = model
+        self._local_optimizer = optimizer
        self._local_step = 0
-        self._started_step = False
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
        device = backup_device or torch.device("cpu")
-
        self._backup_parameters: Dict[str, torch.Tensor] = {}
-
        for name, p in self._model.named_parameters():
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
            if (
@@ -90,95 +90,150 @@ def __init__(
                t = t.pin_memory()
            self._backup_parameters[name] = t

+        self._hooks: List[RemovableHandle] = []
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions

    def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)

    def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._backup_parameters[name], non_blocking=False)

-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
-        Loads the state dict to the model and the backup parameters.
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()

-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+    def sync(self) -> None:
        """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0

-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
        """
-        Run the model parameters.
+        self._average()
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()

-        This should be called before the optimizer step.
+    def _average(self) -> None:
+        # TODO: do we need to broadcast buffers like DDP does?

-        This will start the quorum and save the parameters if this is the first step.
-        """
-        if self._local_step == 0:
-            self._manager.start_quorum()
+        works = []
+
+        for p in self._model.parameters():
+            # TODO: bucketize parameters
+            works.append(self._manager.allreduce(p.data.detach()))

-        self._started_step = True
+        for work in works:
+            work.wait()

-        return self._model.forward(*args, **kwargs)

-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
-        """
-        This hook is registered on the optimizer and is called after the optimizer step.
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between the previous global weights and the current local weights).

-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
+    diloco: https://arxiv.org/pdf/2311.08105
+    """

-        forward must be called before this function.
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        if manager._use_async_quorum:
+            raise ValueError(
+                "Using DiLoCo requires synchronous quorum to be enabled. "
+                "Ensure that the manager is initialized with use_async_quorum=False"
+            )
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
+        step using the outer optimizer.
        """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False

-        self._local_step += 1
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient

-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        self._average_grads()
+        # Restore the parameters back to the previous state
+        self._restore_parameters()

-            if self._manager.should_commit():
-                # save the parameters so we can restore from them later if necessary.
-                self._save_parameters()
-            else:
-                # commit failed, restore from the backup parameters
-                self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+        if self._manager.should_commit():
+            # Use the outer optimizer to update the model parameters
+            self._outer_optimizer.step()
+            self._save_parameters()
+        self._outer_optimizer.zero_grad()

+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
        works = []
-
        for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Perform allreduce on the pseudogradients
+            assert p.grad is not None
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
        for work in works:
            work.wait()
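
For context, here is a minimal sketch of how the new context-manager API might be driven from a training loop. Everything outside the `with LocalSGD(...)` line is an assumption for illustration: the `Manager` construction, the toy model, the loss, and `train_loader` are placeholders, and the module path `torchft.local_sgd` is assumed. Only the constructor signature and the hook-driven sync behavior come from the diff above.

```python
import torch
from torch import nn, optim
from torch.nn import functional as F

from torchft.local_sgd import LocalSGD  # assumed module path for the file in this diff
from torchft.manager import Manager

model = nn.Linear(32, 8)                            # toy model for illustration
optimizer = optim.SGD(model.parameters(), lr=0.01)
manager: Manager = ...                              # assumed to be constructed/configured elsewhere

# Entering the context registers an optimizer post-step hook. Every
# `sync_every` calls to optimizer.step(), the hook calls sync(), which starts
# a quorum, allreduces the parameters, and either commits or restores the backup.
with LocalSGD(manager, model, optimizer, sync_every=100):
    for batch, target in train_loader:              # hypothetical data loader
        optimizer.zero_grad()
        loss = F.cross_entropy(model(batch), target)
        loss.backward()
        optimizer.step()                            # counted by the hook; may trigger sync()
```

On exit the hook is removed, and if the block raises, the parameters are restored from the last synced backup (see `__exit__` above).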
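A corresponding sketch for `DiLoCo`, continuing from the names in the previous example. The inner/outer optimizer choices (AdamW inner, Nesterov SGD outer) follow the DiLoCo paper and are assumptions, as is the expectation that the manager was created with `use_async_quorum=False` (the constructor above only checks `manager._use_async_quorum`).

```python
from torchft.local_sgd import DiLoCo  # assumed module path, as above

inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

# DiLoCo requires a synchronous quorum: the manager is assumed to have been
# created with use_async_quorum=False, otherwise DiLoCo.__init__ raises ValueError.
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=500):
    for batch, target in train_loader:
        inner_optimizer.zero_grad()
        loss = F.cross_entropy(model(batch), target)
        loss.backward()
        inner_optimizer.step()
        # After every sync_every inner steps, the hook computes per-parameter
        # pseudogradients (current local weights minus the last synced weights),
        # allreduces them across the group, and applies them with outer_optimizer.
```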