"""
import logging
from types import TracebackType
- from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

import torch
from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
        model: nn.Module,
        optimizer: optim.Optimizer,
        sync_every: int,
-         backup_device: Optional[torch.device] = None,
-         pin_memory: bool = True,
    ) -> None:
        """
        Args:
@@ -78,21 +76,8 @@ def __init__(
        self._local_step = 0
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-         device = backup_device or torch.device("cpu")
-         self._backup_parameters: Dict[str, torch.Tensor] = {}
-         for name, p in self._model.named_parameters():
-             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-             if (
-                 pin_memory
-                 and t.device == torch.device("cpu")
-                 and torch.cuda.is_available()
-             ):
-                 t = t.pin_memory()
-             self._backup_parameters[name] = t

        self._hooks: List[RemovableHandle] = []
-         # Need to copy the parameters to the host to be safe if we are on the first step.
-         self._save_parameters()

    def __enter__(self) -> "LocalSGD":
        # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -108,37 +93,26 @@ def __exit__(
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
-         if exc_type is not None:
-             # If an exception occurred, restore parameters
-             self._restore_parameters()
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

-     def _save_parameters(self) -> None:
-         with torch.no_grad():
-             # TODO: consider running copy on a separate stream
-             for name, p in self._model.named_parameters():
-                 self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-     def _restore_parameters(self) -> None:
-         with torch.no_grad():
-             # TODO: consider running copy on a separate stream
-             for name, p in self._model.named_parameters():
-                 p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
    def _step_post_hook(
-         self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
-         self._local_step += 1
-         if self._local_step >= self._sync_every:
-             self.sync()
+         try:
+             self._local_step += 1
+             if self._local_step >= self._sync_every:
+                 self.sync()
+         except Exception as e:
+             self._manager.report_error(e)
+             raise

    def sync(self) -> None:
        """
@@ -151,15 +125,9 @@ def sync(self) -> None:
    def _perform_sync(self) -> None:
        """
        Performs the synchronization of the model weights across the manager.
-         This method is intended to be overridden by subclasses to implement custom
-         synchronization logic.
        """
-         self._average()
        if self._manager.should_commit():
-             self._save_parameters()
-         else:
-             # commit failed, restore from the backup parameters
-             self._restore_parameters()
+             self._average()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?
@@ -174,7 +142,7 @@ def _average(self) -> None:
            work.wait()


- class DiLoCo(LocalSGD):
+ class DiLoCo:
    """
    DiLoCo is a subclass of LocalSGD that overrides the synchronization
    mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
@@ -197,21 +165,96 @@ def __init__(
                "Using DiLoCo requires synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )
-         super().__init__(
-             manager, model, inner_optimizer, sync_every, backup_device, pin_memory
-         )
+         super().__init__()
+         self._manager = manager
+         self._model = model
+         self._local_optimizer = inner_optimizer
+         self._local_step = 0
+         self._sync_every = sync_every
+         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+
+         self._hooks: List[RemovableHandle] = []
        self._outer_optimizer = outer_optimizer
+         self._original_parameters: Dict[str, torch.Tensor] = {}
+         for name, p in self._model.named_parameters():
+             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device)
+             if (
+                 pin_memory
+                 and t.device == torch.device("cpu")
+                 and torch.cuda.is_available()
+             ):
+                 t = t.pin_memory()
+             self._original_parameters[name] = t
+
+         # Need to copy the parameters to the host to be safe if we are on the first step.
+         self._save_parameters()
+
+     def _save_parameters(self) -> None:
+         with torch.no_grad():
+             # TODO: consider running copy on a separate stream
+             for name, p in self._model.named_parameters():
+                 self._original_parameters[name].copy_(p.data, non_blocking=True)
+
+     def _restore_parameters(self) -> None:
+         with torch.no_grad():
+             # TODO: consider running copy on a separate stream
+             for name, p in self._model.named_parameters():
+                 p.data.copy_(self._original_parameters[name], non_blocking=False)
+
+     def __enter__(self) -> "DiLoCo":
+         # Add optimizer hook which increments the local step counter and syncs if necessary
+         self._hooks.append(
+             self._local_optimizer.register_step_post_hook(self._step_post_hook)
+         )
+         return self
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_value: Optional[BaseException],
+         traceback: Optional[TracebackType],
+     ) -> bool:
+         # Handle any cleanup or error handling here
+         # Clean up hooks
+         for hook in self._hooks:
+             hook.remove()
+         self._hooks.clear()
+
+         return False  # Propagate exceptions
+
+     def _step_post_hook(
+         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+     ) -> None:
+         """
+         This hook is registered on the optimizer and is called after the optimizer step.
+         """
+         try:
+             self._local_step += 1
+             if self._local_step >= self._sync_every:
+                 self.sync()
+         except Exception as e:
+             self._manager.report_error(e)
+             raise
+
+     def sync(self) -> None:
+         """
+         Synchronizes and averages the model weights across the manager.
+         """
+         self._manager.start_quorum()
+         self._perform_sync()
+         self._local_step = 0

    def _perform_sync(self) -> None:
        """
        Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
        step using the outer optimizer.
        """
+         print("Performing DiLoCo sync", flush=True)

        # Set the .grad field of each parameter to its pseudogradient
        for name, p in self._model.named_parameters():
-             assert name in self._backup_parameters
-             pseudogradient = p.data - self._backup_parameters[name]
+             assert name in self._original_parameters
+             pseudogradient = p.data - self._original_parameters[name]
            p.grad = pseudogradient

        self._average_grads()
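
For context, a rough usage sketch of DiLoCo after this change (not part of the diff). It assumes the constructor order DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every, ...) implied by the assignments in __init__, and it uses placeholder names: `manager` stands for a torchft Manager created elsewhere with use_async_quorum=False (as the check in __init__ requires), and `train_loader` for an ordinary data loader.

import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo  # module path assumed

model = nn.Linear(16, 16)
inner_optimizer = optim.AdamW(model.parameters(), lr=1e-3)
outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

# `manager` and `train_loader` are placeholders; see the note above.
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=100):
    for inputs, targets in train_loader:
        inner_optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        # The post-step hook registered in __enter__ counts local steps and,
        # every `sync_every` steps, runs sync(): start_quorum(), pseudogradient
        # averaging across the group, and the outer optimizer step.
        inner_optimizer.step()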