put ct avals on GradAccums (and UndefinedPrimals), not primal avals#38168
put ct avals on GradAccums (and UndefinedPrimals), not primal avals#38168mattjj wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors JAX's automatic differentiation system to store and propagate cotangent abstract values directly within accumulator classes (such as GradAccum and its subclasses) rather than calling .to_ct_aval() on-the-fly during transpose rules. This simplifies numerous transpose implementations across lax.py and slicing.py. The review feedback highlights three key issues: an incorrect double call to .to_ct_aval() in ad.py, an out-of-order check in _broadcast_in_dim_transpose_rule that could lead to AttributeErrors on literals, and the use of a generic Exception in check_accum which should be replaced with a more specific ValueError.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if type(ct) is ad_util.Zero: | ||
| return [ad_util.Zero(ct_aval)] | ||
| return [ad_util.Zero(operand.aval)] | ||
| if not isinstance(operand, ad.UndefinedPrimal): | ||
| return [None] # transpose wrt literal |
There was a problem hiding this comment.
If operand is not an UndefinedPrimal (e.g., a literal or constant), the transpose rule should return [None]. However, if ct is Zero, the current code returns [ad_util.Zero(operand.aval)] before checking if operand is an UndefinedPrimal. This can lead to incorrect behavior or AttributeError if the literal does not have an aval attribute. Swapping the order of the checks ensures that non-active inputs are correctly ignored first.
| if type(ct) is ad_util.Zero: | |
| return [ad_util.Zero(ct_aval)] | |
| return [ad_util.Zero(operand.aval)] | |
| if not isinstance(operand, ad.UndefinedPrimal): | |
| return [None] # transpose wrt literal | |
| if not isinstance(operand, ad.UndefinedPrimal): | |
| return [None] # transpose wrt literal | |
| if type(ct) is ad_util.Zero: | |
| return [ad_util.Zero(operand.aval)] |
655d8cb to
3a86287
Compare
No description provided.