
Commit 9ff17b1

accuracy fix
1 parent 31cecf0 commit 9ff17b1

File tree: torchft/manager.py, train_ddp2.py

2 files changed (+33 lines, -8 lines)

torchft/manager.py

Lines changed: 19 additions & 4 deletions
@@ -34,7 +34,7 @@
 import uuid
 import weakref
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import nullcontext
+from contextlib import contextmanager
 from datetime import timedelta
 from enum import Enum
 from typing import (
@@ -454,8 +454,11 @@ def allreduce(
 
         # If dirty, the result will not be committed, so return empty tensor.
         if self._dataloader_dirty:
-            work = _DummyWork(torch.zeros_like(tensor))
-            return _ManagedWork(self, work, tensor)
+            tensor.zero_()
+            return _ManagedWork(self, _DummyWork(tensor), tensor)
+
+        if not self.require_backward_grad_sync:
+            return _ManagedWork(self, _DummyWork(tensor), tensor)
 
         num_participants: int = self.num_participants()
 
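Both early returns hand the caller a handle that looks like an already-finished collective, so the surrounding training code can call .wait() unconditionally. A rough, hypothetical sketch of that idea (not torchft's actual _DummyWork, which is defined elsewhere in this file):

# Hypothetical illustration only: a no-op work handle that is already complete.
# The real _DummyWork in torchft may differ; this just shows the contract the
# early returns above rely on.
class NoOpWork:
    def __init__(self, tensor):
        self._tensor = tensor  # the (possibly zeroed) tensor passed straight through

    def wait(self, timeout=None) -> bool:
        # Nothing to synchronize, so report immediate completion.
        return True

    def result(self):
        return self._tensor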
@@ -496,7 +499,7 @@ def callback(
         ) -> torch.Tensor:
             nonlocal tensor
             if reduce_op == ReduceOp.AVG:
-                tensor /= num_participants
+                tensor /= num_participants * self._accumulation_steps
             return tensor
 
         managed_work = _ManagedWork(self, work, tensor)
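This divisor change is the heart of the "accuracy fix": with gradient accumulation, each replica's gradient is a sum over its local micro-batches, so averaging only across replicas would leave the result too large by a factor of the accumulation steps. A small numeric sketch, assuming equal-sized micro-batches and a mean-reduction loss (all names and numbers are illustrative):

# Illustrative numbers: 2 replicas, 4 accumulation steps each, and every
# micro-batch gradient equal to 1.0 for simplicity.
num_participants = 2
accumulation_steps = 4

local_grad = 1.0 * accumulation_steps            # 4.0 summed locally per replica
allreduced_sum = local_grad * num_participants   # 8.0 after the SUM allreduce

# Dividing only by num_participants would report 4.0 (off by the number of
# accumulation steps); dividing by both recovers the per-micro-batch average:
avg_grad = allreduced_sum / (num_participants * accumulation_steps)  # 1.0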
@@ -513,6 +516,15 @@ def callback(
 
         return _DummyWork(tensor)
 
+    @contextmanager
+    def no_sync(self):
+        old_require_backward_grad_sync = self.require_backward_grad_sync
+        self.require_backward_grad_sync = False
+        try:
+            yield
+        finally:
+            self.require_backward_grad_sync = old_require_backward_grad_sync
+
     def report_error(self, e: Exception) -> None:
         """
         Report an error to the manager.
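The new no_sync() mirrors the shape of DDP's no_sync: while the flag is cleared, allreduce short-circuits to a no-op handle, so backward passes only accumulate gradients locally. A minimal usage sketch, assuming manager is a torchft Manager and model, criterion, and micro_batches are placeholders (the train_ddp2.py changes below follow the same pattern):

# Skip gradient synchronization on every micro-batch except the last one.
for step, (inputs, labels) in enumerate(micro_batches):
    loss = criterion(model(inputs), labels)
    if step < len(micro_batches) - 1:
        with manager.no_sync():
            loss.backward()   # gradients accumulate locally, no allreduce
    else:
        loss.backward()       # final micro-batch triggers the usual allreduce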
@@ -931,6 +943,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         Raises:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
+        # Sometimes allreduce is not called before should_commit, so wait for the quorum here.
+        self.wait_quorum()
+
         # make sure recovery is complete before committing
         with torch.profiler.record_function(
             "torchft::manager::should_commmit::recovery_stream::synchronize"

train_ddp2.py

Lines changed: 14 additions & 4 deletions
@@ -71,7 +71,7 @@ def load_model(m, optimizer, manager):
     with open(f"{CHECKPOINT_PATH}_latest", "r") as f:
         latest_checkpoint_path = f.read().strip()
     print(f"Loading checkpoint from {latest_checkpoint_path}")
-    loaded_state_dict = torch.load(latest_checkpoint_path)
+    loaded_state_dict = torch.load(latest_checkpoint_path, weights_only=True)
     m.load_state_dict(loaded_state_dict["model"])
     optimizer.load_state_dict(loaded_state_dict["optim"])
     manager.load_state_dict(loaded_state_dict["torchft"])
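weights_only=True restricts torch.load to tensors and plain containers instead of arbitrary pickled objects, which is the safer choice for checkpoint files (and the default in newer PyTorch releases). A small standalone illustration; the path and toy model are made up:

import torch
import torch.nn as nn

# Save a plain state_dict and load it back with the restricted unpickler.
# Tensors, dicts, and primitives round-trip fine; arbitrary Python objects
# would be rejected under weights_only=True.
model = nn.Linear(4, 2)
torch.save({"model": model.state_dict()}, "/tmp/example_ckpt.pt")
state = torch.load("/tmp/example_ckpt.pt", weights_only=True)
model.load_state_dict(state["model"])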
@@ -89,10 +89,12 @@ def main() -> None:
     )
 
     def load_state_dict(state_dict):
+        print("Received checkpoint!")
         m.load_state_dict(state_dict["model"])
         optimizer.load_state_dict(state_dict["optim"])
 
     def state_dict():
+        print("Setup checkpoint to send!")
         return {
             "model": m.state_dict(),
             "optim": optimizer.state_dict(),
@@ -206,21 +208,29 @@ def forward(self, x):
         ) is not None:
             optimizer.zero_grad()
             total_loss = 0.0
-            for inputs, labels in batches:
+            accumulation_steps = len(batches)
+            for i in range(accumulation_steps):
+                inputs, labels = batches[i]
                 inputs = inputs.to(device)
                 labels = labels.to(device)
                 out = m(inputs)
                 loss = criterion(out, labels)
-                loss.backward()
+                if i == accumulation_steps - 1:
+                    loss.backward()
+                else:
+                    with manager.no_sync():
+                        loss.backward()
                 total_loss += loss.item()
+
             # If errored, the optimizer step will be a no-op, and the parameter will not be updated.
             # Although it is possible to use new pg to compute old batches, it is still safe.
             if not optimizer.step():
                 continue
 
             # all reduce the loss across all replicas
-            total_loss /= len(batches)
+            total_loss = total_loss / BATCH_SIZE
             loss_tensor = torch.tensor(total_loss, device=device)
+            # manager all reduce will divide by replica world size * accumulation steps
             manager.allreduce(loss_tensor).wait()
             avg_loss = loss_tensor.item()
             if manager.participating_rank() == 0:
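To make the loss bookkeeping explicit: after the loop, total_loss on each replica is the sum of accumulation_steps micro-batch losses; the script scales it by 1 / BATCH_SIZE locally, and the manager's AVG allreduce (per the manager.py change above) divides the cross-replica sum by num_participants * accumulation_steps. A small numeric trace; the concrete numbers and the BATCH_SIZE value are illustrative only:

# Illustrative trace of the reported loss value.
num_participants = 2
accumulation_steps = 4
BATCH_SIZE = 64           # placeholder for the script's constant
micro_batch_loss = 2.0    # pretend every micro-batch loss is identical

total_loss = micro_batch_loss * accumulation_steps                    # 8.0 per replica
local_value = total_loss / BATCH_SIZE                                 # 0.125 fed into allreduce
allreduced_sum = local_value * num_participants                       # 0.25 after SUM
reported = allreduced_sum / (num_participants * accumulation_steps)   # 0.03125
# which works out to micro_batch_loss / BATCH_SIZE.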
