|
| 1 | +function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U} |
| 2 | + unitful_x = Quantity{T,D,U}(x) |
| 3 | + projector_x = ProjectTo(x) |
| 4 | + uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT)) |
| 5 | + return unitful_x, uq_pullback |
| 6 | +end |
| 7 | + |
| 8 | +function ProjectTo(x::Quantity) |
| 9 | + project_val = ProjectTo(x.val) # Project the literal number |
| 10 | + return ProjectTo{typeof(x)}(; project_val = project_val) |
| 11 | +end |
| 12 | + |
| 13 | +function (projector::ProjectTo{<:Quantity})(x::Number) |
| 14 | + new_val = projector.project_val(ustrip(x)) |
| 15 | + return new_val*x |
| 16 | +end |
| 17 | + |
| 18 | +# Project Unitful Quantities onto numerical types by projecting the value and carrying units |
| 19 | +(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx) |
| 20 | +(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx) |
| 21 | + |
| 22 | +function rrule(::typeof(*), x::Quantity, y::Units, z::Units...) |
| 23 | + Ω = *(x, y, z...) |
| 24 | + project_x = ProjectTo(x) |
| 25 | + function times_pb(Δ) |
| 26 | + δ = project_x(Δ) |
| 27 | + units = (y, z...) |
| 28 | + return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...) |
| 29 | + end |
| 30 | + return Ω, times_pb |
| 31 | +end |
| 32 | + |
| 33 | +rrule(::typeof(/), x::Number, y::Units) = rrule(*, x, inv(y)) |
| 34 | +rrule(::typeof(/), x::Units, y::Number) = rrule(*, x, inv(y)) |
0 commit comments