
Fix gradient scaling bug with batch size #120

Merged
luciaquirke merged 3 commits into main from fix-batch-size-gradient-scaling
Jan 9, 2026

Conversation


claude bot commented Jan 9, 2026

Summary

Fixes gradient scaling bug where gradient magnitudes varied based on batch size.

The Problem: With losses.mean().backward(), gradients were scaled by 1/batch_size, so processing the same data under different batch configurations produced gradients of different magnitudes, leading to inconsistent gradient scales (issue #112).

The Fix: Change to losses.sum().backward() so that gradients are invariant to batch size. The gradient magnitude now depends only on the data, not on how it's batched.
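A toy illustration of the effect (not the project's code; the scalar model and data below are made up for the example). For per-example losses `loss_i = (w*x_i - y_i)^2`, the analytic gradient is `dloss_i/dw = 2*x_i*(w*x_i - y_i)`; mean reduction divides the summed gradient by the batch size, so splitting the same data into smaller batches changes the accumulated magnitude:

```python
# Toy sketch: analytic gradients for loss_i = (w*x_i - y_i)^2,
# so dloss_i/dw = 2*x_i*(w*x_i - y_i).
def grad_sum(w, xs, ys):
    """Gradient of losses.sum() w.r.t. w."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys))

def grad_mean(w, xs, ys):
    """Gradient of losses.mean() w.r.t. w: scaled by 1/len(xs)."""
    return grad_sum(w, xs, ys) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 1.0, 0.0, 3.0]

# Sum reduction: two half-batches accumulate to the same gradient as
# one full batch -- invariant to how the data is batched.
full = grad_sum(w, xs, ys)
halves = grad_sum(w, xs[:2], ys[:2]) + grad_sum(w, xs[2:], ys[2:])
assert abs(full - halves) < 1e-12

# Mean reduction: each batch divides by its own size, so the full-batch
# gradient is half the accumulated half-batch gradients here.
full_m = grad_mean(w, xs, ys)
halves_m = grad_mean(w, xs[:2], ys[:2]) + grad_mean(w, xs[2:], ys[2:])
print(full_m, halves_m)  # -0.5 vs -1.0: magnitudes differ by the batch-size factor
```
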

Changes

Test Plan

The new test test_batch_size_invariance.py verifies that:

  1. Gradient standard deviations are consistent (within 20% tolerance) when processing data separately vs together
  2. Gradients point in the same direction (high cosine similarity)

This test is based on David's replication script from the data-args branch that demonstrated the original bug.
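A minimal sketch of the two checks described above, with names and thresholds assumed from this description rather than taken from the repository's actual test file:

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two gradient vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def std(v):
    """Population standard deviation of a gradient vector."""
    m = sum(v) / len(v)
    return math.sqrt(sum((x - m) ** 2 for x in v) / len(v))

def check_invariance(grads_separate, grads_together, tol=0.2):
    """Assert stds agree within `tol` relative difference and the
    gradients point in (nearly) the same direction."""
    s1, s2 = std(grads_separate), std(grads_together)
    assert abs(s1 - s2) / max(s1, s2) <= tol
    assert cosine_similarity(grads_separate, grads_together) > 0.99

# With sum reduction, accumulating per-chunk gradients reproduces the
# full-batch gradient, so the check passes on identical vectors.
g = [0.3, -1.2, 0.7, 0.05]
check_invariance(g, list(g))
```
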

Fixes #112


🤖 Generated with Claude Code

Change loss reduction from mean to sum in backward pass to make
gradient scales invariant to batch size. This ensures that gradients
have consistent magnitudes regardless of how the data is batched.

Add test based on issue #112 replication script to verify that
gradient scales are now consistent when processing data separately
vs together.

Fixes #112

Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>
claude bot mentioned this pull request Jan 9, 2026
claude bot pushed a commit that referenced this pull request Jan 9, 2026
Remove the losses.sum() change since it was split into standalone PR #120.
This PR should focus only on dtype utilities.

Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>

claude bot commented Jan 9, 2026


Code review

No issues found. Checked for bugs and CLAUDE.md compliance.



luciaquirke force-pushed the fix-batch-size-gradient-scaling branch from 13d720f to f908d2c on January 9, 2026 at 12:34
luciaquirke self-requested a review on January 9, 2026 at 12:34
luciaquirke force-pushed the fix-batch-size-gradient-scaling branch from f908d2c to 95c057d on January 9, 2026 at 12:38
luciaquirke (Collaborator) left a comment


Good job claude

luciaquirke force-pushed the fix-batch-size-gradient-scaling branch from b81b919 to fd6709f on January 9, 2026 at 12:52
luciaquirke force-pushed the fix-batch-size-gradient-scaling branch from fd6709f to 5e5751d on January 9, 2026 at 12:53
luciaquirke merged commit 689f17b into main on Jan 9, 2026
6 checks passed
luciaquirke added a commit that referenced this pull request Jan 11, 2026
Remove the losses.sum() change since it was split into standalone PR #120.
This PR should focus only on dtype utilities.

Co-authored-by: Lucia Quirke <luciaquirke@users.noreply.github.com>
luciaquirke deleted the fix-batch-size-gradient-scaling branch on January 13, 2026 at 23:20


Development

Successfully merging this pull request may close these issues.

Numerical instability: outlier gradients leading to non-reproducible results
