Fix gradient scaling bug with batch size #120
Merged
luciaquirke merged 3 commits into main on Jan 9, 2026
Conversation
Change loss reduction from mean to sum in backward pass to make gradient scales invariant to batch size. This ensures that gradients have consistent magnitudes regardless of how the data is batched. Add test based on issue #112 replication script to verify that gradient scales are now consistent when processing data separately vs together. Fixes #112 Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>
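The effect described here can be illustrated with a small stdlib-only sketch. The data, the scalar squared-error model, and the hand-derived gradient below are all hypothetical illustrations, not code from this PR:

```python
# Why mean-reduced losses give batch-size-dependent gradient scales while
# sum-reduced losses do not. For a scalar model with per-example loss
# loss_i(w) = (w * x_i - y_i)**2, the gradient is
# dloss_i/dw = 2 * (w * x_i - y_i) * x_i.

def grad(w, xs, ys, reduction):
    """Gradient of the reduced loss with respect to w."""
    per_example = [2 * (w * x - y) * x for x, y in zip(xs, ys)]
    total = sum(per_example)
    if reduction == "mean":
        return total / len(xs)  # scaled by 1/batch_size
    return total                # "sum": independent of batching

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Process all four examples together vs. in two halves of two.
g_sum_all = grad(w, xs, ys, "sum")
g_sum_split = grad(w, xs[:2], ys[:2], "sum") + grad(w, xs[2:], ys[2:], "sum")

g_mean_all = grad(w, xs, ys, "mean")
g_mean_split = grad(w, xs[:2], ys[:2], "mean") + grad(w, xs[2:], ys[2:], "mean")

print(g_sum_all == g_sum_split)    # → True: same gradient either way
print(g_mean_all == g_mean_split)  # → False: magnitude depends on batching
```

With sum reduction the accumulated gradient is the same (-90.0) whether the four examples are processed together or in two batches; with mean reduction the split run produces a gradient twice as large (-45.0 vs. -22.5), which is exactly the inconsistency issue #112 reports.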
claude bot pushed a commit that referenced this pull request on Jan 9, 2026:
Remove the losses.sum() change since it was split into standalone PR #120. This PR should focus only on dtype utilities. Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>
Author
Code review: No issues found. Checked for bugs and CLAUDE.md compliance.
Force-pushed 13d720f to f908d2c
for more information, see https://pre-commit.ci
Force-pushed f908d2c to 95c057d
Force-pushed b81b919 to fd6709f
Force-pushed fd6709f to 5e5751d
luciaquirke added a commit that referenced this pull request on Jan 11, 2026:
Remove the losses.sum() change since it was split into standalone PR #120. This PR should focus only on dtype utilities. Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>
Summary

Fixes a gradient scaling bug where gradient magnitudes varied with batch size.

The Problem: With `losses.mean().backward()`, gradients were scaled by `1/batch_size`. Processing the same data in different batch configurations therefore produced gradients with different magnitudes, leading to inconsistent gradient scales (issue #112).

The Fix: Change to `losses.sum().backward()` so that gradients are invariant to batch size. The gradient magnitude now depends only on the data, not on how it's batched.

Changes

- Change `losses.mean().backward()` to `losses.sum().backward()`

Test Plan

The new test `test_batch_size_invariance.py` verifies that gradient scales are consistent when the same data is processed separately vs. together. It is based on David's replication script from the data-args branch that demonstrated the original bug.
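As a rough sketch of the shape such an invariance test can take (a stdlib-only toy model with made-up data and a hand-derived sum-reduced gradient; the actual `test_batch_size_invariance.py` is not shown on this page):

```python
# Hypothetical sketch of a batch-size invariance test. A toy "backward pass"
# computes the gradient of the sum-reduced loss sum_i (w*x_i - y_i)**2 for a
# scalar weight; the test asserts that one big batch and per-example batches
# accumulate exactly the same gradient.

def sum_loss_grad(w, batch):
    # d/dw of sum over the batch of (w*x - y)**2
    return sum(2 * (w * x - y) * x for x, y in batch)

def test_batch_size_invariance():
    w = 0.3
    data = [(1.0, 1.0), (2.0, 3.0), (3.0, 2.0), (4.0, 5.0)]
    together = sum_loss_grad(w, data)
    separately = sum(sum_loss_grad(w, data[i:i + 1]) for i in range(len(data)))
    assert together == separately

test_batch_size_invariance()  # passes: sum reduction is batch-size invariant
```

With mean reduction the same comparison fails, because each sub-batch's gradient is divided by its own (smaller) batch size before accumulation.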
Fixes #112
🤖 Generated with Claude Code