ChainRules rrule Integration for Unitful
#504
oxinabox
left a comment
Seems fine to me
src/chainrules.jl
Outdated
    function ProjectTo(x::Quantity)
        project_val = ProjectTo(x.val) # Project the literal number
        return ProjectTo{typeof(x)}(; project_val = project_val)
    end
Doesn't really matter, but the convention ChainRulesCore uses is to match the field name, if in doubt.
    - function ProjectTo(x::Quantity)
    -     project_val = ProjectTo(x.val) # Project the literal number
    -     return ProjectTo{typeof(x)}(; project_val = project_val)
    - end
    + function ProjectTo(x::Quantity)
    +     val = ProjectTo(x.val) # Project the literal number
    +     return ProjectTo{typeof(x)}(; val = val)
    + end
    function (projector::ProjectTo{<:Quantity})(x::Number)
        new_val = projector.project_val(ustrip(x))
convention:
    - function (projector::ProjectTo{<:Quantity})(x::Number)
    -     new_val = projector.project_val(ustrip(x))
    + function (project::ProjectTo{<:Quantity})(x::Number)
    +     new_val = project.val(ustrip(x))
src/chainrules.jl
Outdated
    function ProjectTo(x::Quantity)
        project_val = ProjectTo(x.val) # Project the literal number
        return ProjectTo{typeof(x)}(; project_val = project_val)
This stores the complete type, and hence unit, of x in the type of the projector. But when this is applied, you use only unit(dx) and not this unit. That's mathematically correct, I think, since the gradient will typically have different units. But it also means that storing this is redundant.
It could just be ProjectTo{Quantity} -- many of them store just the top-level type. But could it just be ProjectTo(x.val)?
Right now ProjectTo{Float64} will allow through any dx with units, without changing them. To get the present behaviour of this PR, could you just define methods for (::ProjectTo{<:Number})(dx::Quantity) which un-wrap, adjust the precision etc if necessary, and re-wrap? Or does this land you in dispatch ambiguity hell?
oh true, i missed that, my bad
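A minimal sketch of the un-wrap, project, re-wrap idea raised in this thread, written as a free-standing helper rather than as a new method on ProjectTo. The name `project_quantity` is hypothetical and belongs to neither Unitful.jl nor ChainRulesCore.jl. Whether the same logic can be added as a `(::ProjectTo{<:Number})(dx::Quantity)` method without dispatch ambiguities is not checked here.

```julia
using ChainRulesCore
using Unitful

# Hypothetical helper (not part of either package): project the numeric part
# of `dx` with `p`, then re-attach the units that `dx` itself carried,
# not the units of the primal.
project_quantity(p::ProjectTo, dx::Quantity) = p(ustrip(dx)) * unit(dx)

p = ProjectTo(1.0)  # projector built from a Float64 primal

# A Float32 cotangent value is converted to Float64; the unit is kept.
dx32 = project_quantity(p, 2.0f0 * u"m")

# A complex cotangent is projected onto the reals; the unit is kept.
dxc = project_quantity(p, (1.0 + 2.0im) * u"m")
```

The projector here carries no unit information at all, which is the point made in the comment: only the units of the incoming `dx` are used.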
src/chainrules.jl
Outdated
    units = (y, z...)
    return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
Why not
    - units = (y, z...)
    - return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
    + return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), 1+length(z))...)
Was going to suggest something like:
nots = ntuple(Returns(NoTangent()), 1 + length(z))
return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)
since I think there is little to gain by making the pullback close over the ProjectTo instead of over x. But something to gain in readability by needing fewer symbols. (But this is just style.)
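For context, the suggestion above could sit inside a complete rule roughly as follows. This is a sketch only: `attach_units` is a hypothetical stand-in for `*(x::Number, y::Units, z::Units...)`, used so that the example does not add methods to functions owned by Base or Unitful, and `Returns` assumes Julia 1.7 or later.

```julia
using ChainRulesCore
using Unitful

# Hypothetical stand-in for multiplying a number by one or more units.
attach_units(x::Number, y::Unitful.Units, z::Unitful.Units...) = *(x, y, z...)

function ChainRulesCore.rrule(
    ::typeof(attach_units), x::Number, y::Unitful.Units, z::Unitful.Units...
)
    Ω = attach_units(x, y, z...)
    function attach_units_pullback(Δ)
        # One NoTangent per unit argument: `y` plus every element of `z`.
        nots = ntuple(Returns(NoTangent()), 1 + length(z))
        # The pullback closes over `x` and builds the projector when called.
        return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)
    end
    return Ω, attach_units_pullback
end

Ω, pullback = rrule(attach_units, 3.0, u"m", u"s^-1")
grads = pullback(1.0)
```

The returned tuple has one entry for the function itself, one for `x`, and one `NoTangent()` for each unit argument.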
    function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
        unitful_x = Quantity{T,D,U}(x)
        projector_x = ProjectTo(x)
        uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT))
        return unitful_x, uq_pullback
    end
When is this called?
If it's used when attaching units to an initially plain number, x=1 -> unitful_x = 1m, then the thinking is that if the loss is a unitless scalar, the gradient for unitful_x will be d loss / d unitful_x = 100/m, and this will produce a gradient for x with no units (or units equivalent to 1)?
And does that work out in practice? With some Zygote.gradient(loss, 1u"m")... must you ensure by hand that you remove the units within loss, or does Zygote.sensitivity do the right thing? Maybe that's a bigger question than this function... I have not thought much about how this all ought to work.
Zygote.sensitivity returns the multiplicative identity which is usually 1.0, even for Unitful.Quantity. The example worked out in @oxinabox's comment matches how I've thought about this, so I think Zygote is correct here.
I am trying to think this one through. Consider:
    function f(t)
        a = 5 Meter/Second
        x = a*t
        return x
    end
So I initially assumed (incorrectly) that the seed co-tangent has the same units as a difference of primals, which is the same as the primal units in most (all?) cases. Then we would get [expression not preserved]. So I guess the seed must have units of [expression not preserved], which was what was wanted.
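The unit bookkeeping for this example can be checked with plain Unitful arithmetic. The sketch below involves no AD system: the pullback is written by hand, the concrete values are illustrative, and the identity seed follows the earlier reply that Zygote.sensitivity returns the multiplicative identity.

```julia
using Unitful

a = 5.0u"m/s"
t = 2.0u"s"
x = a * t  # primal output, a length

# Hand-written pullback for x = a*t with respect to t (real-valued case):
# tbar = a * xbar.
pullback_t(xbar) = a * xbar

# Seed taken as the multiplicative identity. For a Quantity, `one` returns a
# plain number without units.
tbar_identity = pullback_t(one(x))

# Seed carrying the primal's units, the assumption described above as incorrect.
tbar_primal = pullback_t(oneunit(x))
```

With the identity seed the result has the units of dx/dt (a length per time). With the primal-unit seed it picks up an extra factor of length.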
It's above my abilities to review this but I just wanted to say that this would be a great addition to AD in Julia.
fbe6810 to 1ea376a
Codecov Report
@@ Coverage Diff @@
## master #504 +/- ##
==========================================
+ Coverage 84.94% 88.00% +3.05%
==========================================
Files 16 17 +1
Lines 1448 1467 +19
==========================================
+ Hits 1230 1291 +61
+ Misses 218 176 -42
Are there any further steps required before this can get merged? Should there be a manual rrule implemented for the tests, maybe?
Added some tests for the I left the
Is there anything else that needs doing or can this be merged?
This only implements the Relatedly, @oxinabox do you think we need the
rrule Integration for Unitful
add endline Co-authored-by: Mosè Giordano
bd88a9d to 5e24deb
Accidentally bumped the patch version before; should be good now
This is overall ok with me, but the question is do we want to add another dependency? This package has been traditionally rather conservative on taking on dependencies, however there is already a non-standard library and @ajkeller34 @sostock opinions?
I am awfully new to Julia and the SciML ecosystem. Would this addition make it possible to run
It should now be even lighter |
Is there anything left to do on this PR?
Closing this PR. I turned this into its own package: https://github.com/SBuercklin/UnitfulChainRules.jl I just submitted the registration on the General registry, once that clears I'll submit a PR adding a link to the |
The intention of this PR is to implement the machinery within Unitful.jl to allow for autodiff over Unitful.Quantitys. Specifically, it should include the necessary ChainRulesCore.jl methods to provide some basic level of compatibility with ChainRules-based AD systems.
Before this PR, Quantitys would be reduced to Tangent{Any}(val = ...) which would break a lot of basic AD arithmetic. After this PR:
I've implemented an rrule for the Quantity constructor, ProjectTo{Quantity}, and arithmetic between Number/Quantity and Units, which are used to call the Quantity constructor. Generally speaking, projecting a Quantity to a Number involves projecting the value of the Quantity onto the number, and then propagating the units of the projecting Quantity onto the Number. This ensures the proper real/complex behavior is obeyed while maintaining correct units.
I wanted to get feedback before continuing to ensure that:
Testing is difficult as ChainRulesTestUtils.jl does not play nicely with Unitful.jl at the moment. I can manually test the rrules and ProjectTo, but if there are any other testing approaches I'm open to ideas.
Regarding dependencies elsewhere, more full compatibility with ChainRules.jl needs this PR which relaxes the constraint over many of the rules from Union{Real, Complex} to just Number. This should give compatibility with Quantitys which subtype Number, but typically wrap <:Real, <:Complex.
Remaining work:
Quantitys and Units (right now only * and / are implemented)
frules