Skip to content

Fix per-example loss scaling for mixed-length batches#572

Open
taivu1998 wants to merge 1 commit into
aqlaboratory:mainfrom
taivu1998:tdv/issue-517-loss-scaling
Open

Fix per-example loss scaling for mixed-length batches#572
taivu1998 wants to merge 1 commit into
aqlaboratory:mainfrom
taivu1998:tdv/issue-517-loss-scaling

Conversation

@taivu1998
Copy link
Copy Markdown

Summary

  • Keep each AlphaFold loss component as a per-example tensor inside AlphaFoldLoss until final aggregation.
  • Apply sqrt(min(seq_length, crop_len)) independently per local batch example before taking the final mean.
  • Add focused regression coverage for per-example reduction paths and mixed-length aggregation.

Root Cause

Issue #517 points out that the loss scale was computed from the average sequence length of the local batch, then applied after the component losses had already been averaged. For local batches containing examples with different sequence lengths, this gives the wrong scaling factor for every example except by coincidence.

Changes

  • Added lightweight loss reduction helpers so existing public loss functions still default to scalar means while AlphaFoldLoss can request reduction="none".
  • Updated distogram, experimentally resolved, FAPE/backbone, pLDDT, masked MSA, supervised chi, TM, and violation losses to preserve per-example values when requested.
  • Changed AlphaFoldLoss to validate per-example component shapes, aggregate weighted per-example losses, apply per-example sequence-length scaling, and then mean over the local batch.
  • Fixed violation clash normalization to use per-example atom counts instead of one local-batch-wide atom count.
  • Added tests for the new reduction contract and for the exact mixed-length scaling regression.

Fixes #517.

Validation

  • python -m py_compile openfold/utils/loss.py tests/test_loss.py
  • git diff --check
  • Dependency-light smoke validation covering the changed helpers, component reduction="none" paths, violation atom normalization, and full AlphaFoldLoss mixed-length aggregation.

Notes

The full targeted pytest command could not run in this local checkout because the available Python environment is incomplete/broken: uv run pytest ... fails on invalid local Anaconda llvmlite egg metadata, and python -m pytest ... fails during collection because ml_collections is missing.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Shouldn't loss for each example be weighed by the sqrt of the length?

1 participant