@@ -71,7 +71,7 @@ def load_model(m, optimizer, manager):
     with open(f"{CHECKPOINT_PATH}_latest", "r") as f:
         latest_checkpoint_path = f.read().strip()
     print(f"Loading checkpoint from {latest_checkpoint_path}")
-    loaded_state_dict = torch.load(latest_checkpoint_path)
+    loaded_state_dict = torch.load(latest_checkpoint_path, weights_only=True)
     m.load_state_dict(loaded_state_dict["model"])
     optimizer.load_state_dict(loaded_state_dict["optim"])
     manager.load_state_dict(loaded_state_dict["torchft"])
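For reference, the checkpoint read by load_model above would be produced by a save path shaped roughly like the sketch below. This is a hypothetical counterpart, not the file's actual save code: the save_model name and step argument are made up, and only the dictionary keys ("model", "optim", "torchft"), the CHECKPOINT_PATH variable, and the "_latest" marker-file convention come from the load code; manager.state_dict() is assumed to be the counterpart of the manager.load_state_dict() call above. Persisting plain state dicts like this is also what should keep the torch.load(..., weights_only=True) call safe on the way back in.

import torch

def save_model(m, optimizer, manager, step):
    # Hypothetical save counterpart: writes the same keys that load_model reads.
    path = f"{CHECKPOINT_PATH}_step_{step}"
    torch.save(
        {
            "model": m.state_dict(),
            "optim": optimizer.state_dict(),
            "torchft": manager.state_dict(),
        },
        path,
    )
    # Point the "_latest" marker file at the newest checkpoint so load_model finds it.
    with open(f"{CHECKPOINT_PATH}_latest", "w") as f:
        f.write(path)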
@@ -89,10 +89,12 @@ def main() -> None:
     )

     def load_state_dict(state_dict):
+        print("Received checkpoint!")
         m.load_state_dict(state_dict["model"])
         optimizer.load_state_dict(state_dict["optim"])

     def state_dict():
+        print("Setup checkpoint to send!")
         return {
             "model": m.state_dict(),
             "optim": optimizer.state_dict(),
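The two callbacks above are what the torchft Manager uses to move a live checkpoint from a healthy replica (state_dict) to a recovering one (load_state_dict); the new print statements make that handoff visible in the logs. A minimal sketch of how they are typically wired in, assuming the top-level torchft Manager and ProcessGroupGloo imports; the other constructor arguments shown are placeholders rather than this file's actual values:

from torchft import Manager, ProcessGroupGloo

# Sketch only: hand the callbacks defined above to the Manager so it can
# request a checkpoint from a healthy replica and deliver it to a recovering one.
manager = Manager(
    pg=ProcessGroupGloo(),  # placeholder process group
    min_replica_size=1,     # placeholder quorum size
    load_state_dict=load_state_dict,
    state_dict=state_dict,
)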
@@ -206,21 +208,29 @@ def forward(self, x):
     ) is not None:
         optimizer.zero_grad()
         total_loss = 0.0
-        for inputs, labels in batches:
+        accumulation_steps = len(batches)
+        for i in range(accumulation_steps):
+            inputs, labels = batches[i]
             inputs = inputs.to(device)
             labels = labels.to(device)
             out = m(inputs)
             loss = criterion(out, labels)
-            loss.backward()
+            if i == accumulation_steps - 1:
+                loss.backward()
+            else:
+                with manager.no_sync():
+                    loss.backward()
             total_loss += loss.item()
+
         # If errored, the optimizer step will be a no-op, and the parameters will not be updated.
         # Although it is possible to use a new pg to compute old batches, it is still safe.
         if not optimizer.step():
             continue

         # all-reduce the loss across all replicas
-        total_loss /= len(batches)
+        total_loss = total_loss / BATCH_SIZE
         loss_tensor = torch.tensor(total_loss, device=device)
+        # manager.allreduce will divide by replica world size * accumulation steps
         manager.allreduce(loss_tensor).wait()
         avg_loss = loss_tensor.item()
         if manager.participating_rank() == 0:
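The accumulation loop above follows the usual pattern of skipping gradient synchronization for every micro-batch except the last one, so the all-reduce cost is paid once per optimizer step. For comparison, the same idiom with stock DistributedDataParallel looks like the sketch below; the function and variable names are illustrative, not taken from this file:

def accumulate_and_step(ddp_model, criterion, optimizer, micro_batches, device):
    # ddp_model is assumed to be a torch.nn.parallel.DistributedDataParallel instance.
    optimizer.zero_grad()
    steps = len(micro_batches)
    for i, (inputs, labels) in enumerate(micro_batches):
        inputs, labels = inputs.to(device), labels.to(device)
        loss = criterion(ddp_model(inputs), labels)
        if i < steps - 1:
            # Suppress the gradient all-reduce for intermediate micro-batches.
            with ddp_model.no_sync():
                loss.backward()
        else:
            # The final backward triggers a single synchronized all-reduce.
            loss.backward()
    optimizer.step()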