Skip to content

Commit 1ea376a

Browse files
committed
ProjectTo maps to projecting onto the inner val, cleaner NoTangents in * pullback
1 parent d4e0fc8 commit 1ea376a

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

src/chainrules.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,22 @@ function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
55
return unitful_x, uq_pullback
66
end
77

8-
function ProjectTo(x::Quantity)
9-
project_val = ProjectTo(x.val) # Project the literal number
10-
return ProjectTo{typeof(x)}(; project_val = project_val)
11-
end
12-
138
function (projector::ProjectTo{<:Quantity})(x::Number)
149
new_val = projector.project_val(ustrip(x))
1510
return new_val*unit(x)
1611
end
1712

1813
# Project Unitful Quantities onto numerical types by projecting the value and carrying units
14+
ProjectTo(x::Quantity) = ProjectTo(x.val)
15+
1916
(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx)
2017
(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx)
2118

2219
function rrule(::typeof(*), x::Quantity, y::Units, z::Units...)
2320
Ω = *(x, y, z...)
24-
project_x = ProjectTo(x)
2521
function times_pb(Δ)
26-
δ = project_x(Δ)
27-
units = (y, z...)
28-
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
22+
nots = ntuple(_ -> NoTangent(), 1 + length(z))
23+
return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)
2924
end
3025
return Ω, times_pb
3126
end

0 commit comments

Comments
 (0)