Commit ea4d757

Revert batch size gradient scaling fix

Remove the losses.sum() change, since it was split out into standalone PR #120. This PR should focus only on dtype utilities.

Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>

1 parent 25b3133 · commit ea4d757

File tree

1 file changed: 1 addition, 1 deletion


bergson/collector/collector.py

@@ -565,7 +565,7 @@ def fwd_bwd(model, batch):
     if "advantage" in batch:
         losses *= torch.tensor(batch["advantage"], device=losses.device)
 
-    losses.sum().backward()
+    losses.mean().backward()
     model.zero_grad()
 
     return losses
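The substantive difference between the two reductions is gradient scale: backpropagating through a sum gives every per-example loss a gradient contribution of 1, while backpropagating through a mean divides each contribution by the batch size, so gradient magnitudes no longer grow with larger batches. A minimal dependency-free sketch of the analytic gradients (the helper `grad_of_reduction` is hypothetical, written only to illustrate the scaling, and is not part of the repository's code):

```python
def grad_of_reduction(losses, reduce):
    """Analytic gradient of reduce(losses) w.r.t. each element of losses."""
    n = len(losses)
    if reduce == "sum":
        # d(sum_i l_i)/d(l_j) = 1 for every j, regardless of batch size
        return [1.0] * n
    if reduce == "mean":
        # d((1/n) * sum_i l_i)/d(l_j) = 1/n, shrinking with batch size
        return [1.0 / n] * n
    raise ValueError(f"unknown reduction: {reduce}")

batch_losses = [0.3, 0.7, 1.1, 0.9]
g_sum = grad_of_reduction(batch_losses, "sum")    # [1.0, 1.0, 1.0, 1.0]
g_mean = grad_of_reduction(batch_losses, "mean")  # [0.25, 0.25, 0.25, 0.25]
```

Under `sum`, doubling the batch size roughly doubles the gradient norm; under `mean`, gradients stay comparable across batch sizes, which is why the choice of reduction is effectively a batch-size gradient-scaling decision.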
