put ct avals on GradAccums (and UndefinedPrimals), not primal avals #38168

Open
mattjj wants to merge 1 commit into jax-ml:main from mattjj:accum-ct-avals

mattjj (Collaborator) commented Jun 3, 2026

No description provided.

mattjj requested a review from yashk2810 June 3, 2026 20:48
mattjj self-assigned this Jun 3, 2026
gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request refactors JAX's automatic differentiation system to store and propagate cotangent abstract values directly on accumulator classes (such as GradAccum and its subclasses), rather than calling .to_ct_aval() on the fly inside transpose rules. This simplifies numerous transpose implementations across lax.py and slicing.py. The review feedback highlights three issues: an incorrect double call to .to_ct_aval() in ad.py, an out-of-order check in _broadcast_in_dim_transpose_rule that could raise AttributeError on literals, and a generic Exception in check_accum that should be replaced with a more specific ValueError.
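
As a rough illustration of the refactor summarized above, the toy sketch below stores the cotangent aval on the accumulator once, at construction, so that consumers read it directly instead of converting a primal aval at each use. `ToyAval`, `ToyAccum`, and this `check_accum` are hypothetical stand-ins written for the example, not JAX's actual classes or signatures; only the `.to_ct_aval()` method name comes from the summary, and the dtype mapping is a simplifying assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyAval:
    """Hypothetical stand-in for an abstract value (shape plus dtype name)."""
    shape: tuple
    dtype: str

    def to_ct_aval(self) -> "ToyAval":
        # Assumption for this toy: integer primals get a 'float0' cotangent
        # dtype; every other dtype is kept unchanged.
        ct_dtype = "float0" if self.dtype.startswith("int") else self.dtype
        return ToyAval(self.shape, ct_dtype)


def check_accum(expected: ToyAval, got: ToyAval) -> None:
    # Raise a specific exception type rather than a bare Exception.
    if expected != got:
        raise ValueError(f"cotangent aval mismatch: expected {expected}, got {got}")


class ToyAccum:
    """Hypothetical accumulator that carries the cotangent aval directly."""

    def __init__(self, primal_aval: ToyAval):
        # Convert exactly once here; callers use .aval as-is and must not
        # call .to_ct_aval() on it a second time.
        self.aval = primal_aval.to_ct_aval()
        self.total = None

    def accum(self, ct_aval: ToyAval, value) -> None:
        check_accum(self.aval, ct_aval)
        self.total = value if self.total is None else self.total + value


acc = ToyAccum(ToyAval((3,), "float32"))
acc.accum(ToyAval((3,), "float32"), 1.0)
acc.accum(ToyAval((3,), "float32"), 2.0)

int_acc = ToyAccum(ToyAval((2,), "int32"))
```

In this toy, `int_acc.aval` already holds the cotangent dtype, so a transpose rule that reads `accum.aval` needs no further conversion.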


Comment thread jax/_src/interpreters/ad.py Outdated
Comment thread jax/_src/lax/lax.py Outdated
Comment on lines 6867 to 6870
   if type(ct) is ad_util.Zero:
-    return [ad_util.Zero(ct_aval)]
+    return [ad_util.Zero(operand.aval)]
   if not isinstance(operand, ad.UndefinedPrimal):
     return [None]  # transpose wrt literal

Priority: high

If operand is not an UndefinedPrimal (e.g., a literal or constant), the transpose rule should return [None]. However, if ct is Zero, the current code returns [ad_util.Zero(operand.aval)] before checking if operand is an UndefinedPrimal. This can lead to incorrect behavior or AttributeError if the literal does not have an aval attribute. Swapping the order of the checks ensures that non-active inputs are correctly ignored first.

Suggested change
-  if type(ct) is ad_util.Zero:
-    return [ad_util.Zero(operand.aval)]
-  if not isinstance(operand, ad.UndefinedPrimal):
-    return [None]  # transpose wrt literal
+  if not isinstance(operand, ad.UndefinedPrimal):
+    return [None]  # transpose wrt literal
+  if type(ct) is ad_util.Zero:
+    return [ad_util.Zero(operand.aval)]
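
The effect of the ordering can be sketched with toy stand-ins. `Zero`, `UndefinedPrimal`, and `Literal` below are hypothetical minimal classes written for this example, not the real `ad_util.Zero` or `ad.UndefinedPrimal`, and the sketch assumes, as the comment does, that a literal operand may lack an `aval` attribute.

```python
class Zero:
    """Toy symbolic zero carrying an aval."""
    def __init__(self, aval):
        self.aval = aval


class UndefinedPrimal:
    """Toy marker for an input we are transposing with respect to."""
    def __init__(self, aval):
        self.aval = aval


class Literal:
    """Toy constant operand with no .aval attribute (assumption)."""
    def __init__(self, val):
        self.val = val


def transpose_zero_first(ct, operand):
    # Order flagged in the review: the Zero check runs before the literal check,
    # so operand.aval is read even when operand is a literal.
    if type(ct) is Zero:
        return [Zero(operand.aval)]
    if not isinstance(operand, UndefinedPrimal):
        return [None]  # transpose wrt literal
    return [ct]  # placeholder for the real transpose computation


def transpose_literal_first(ct, operand):
    # Suggested order: non-active inputs are ignored before anything else.
    if not isinstance(operand, UndefinedPrimal):
        return [None]  # transpose wrt literal
    if type(ct) is Zero:
        return [Zero(operand.aval)]
    return [ct]  # placeholder for the real transpose computation


ct = Zero(aval="f32[3]")
lit = Literal(1.0)

try:
    transpose_zero_first(ct, lit)
    zero_first_outcome = "returned"
except AttributeError:
    zero_first_outcome = "AttributeError"

literal_first_outcome = transpose_literal_first(ct, lit)
```

With these stand-ins, the zero-first order fails on the literal while the literal-first order returns `[None]`; for an `UndefinedPrimal` operand the two orders agree.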

Comment thread jax/_src/api.py
Comment thread jax/_src/interpreters/ad.py Outdated
google-ml-butler Bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Jun 3, 2026
mattjj force-pushed the accum-ct-avals branch 3 times, most recently from 655d8cb to 3a86287, June 4, 2026 01:17
Labels

kokoro:force-run pull ready Ready for copybara import and testing

2 participants