Skip to content

Commit 4dfdd69

Browse files
committed
added basic rrules, ProjectTo for Quantity
1 parent 2a3308e commit 4dfdd69

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "1986cc42-f94f-5a68-af5c-568840ba703d"
33
version = "1.9.2"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
78
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/Unitful.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import Random
2727

2828
import ConstructionBase: constructorof
2929

30+
import ChainRulesCore: rrule, NoTangent, ProjectTo
31+
3032
export logunit, unit, absoluteunit, dimension, uconvert, ustrip, upreferred
3133
export @dimension, @derived_dimension, @refunit, @unit, @affineunit, @u_str
3234
export Quantity, DimensionlessQuantity, NoUnits, NoDims
@@ -69,5 +71,6 @@ include("logarithm.jl")
6971
include("complex.jl")
7072
include("pkgdefaults.jl")
7173
include("dates.jl")
74+
include("chainrules.jl")
7275

7376
end

src/chainrules.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
2+
unitful_x = Quantity{T,D,U}(x)
3+
projector_x = ProjectTo(x)
4+
uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT))
5+
return unitful_x, uq_pullback
6+
end
7+
8+
function ProjectTo(x::Quantity)
9+
project_val = ProjectTo(x.val) # Project the literal number
10+
return ProjectTo{typeof(x)}(; project_val = project_val)
11+
end
12+
13+
function (projector::ProjectTo{<:Quantity})(x::Number)
14+
new_val = projector.project_val(ustrip(x))
15+
return new_val*x
16+
end
17+
18+
# Project Unitful Quantities onto numerical types by projecting the value and carrying units
19+
(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx)
20+
(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx)
21+
22+
function rrule(::typeof(*), x::Quantity, y::Units, z::Units...)
23+
Ω = *(x, y, z...)
24+
project_x = ProjectTo(x)
25+
function times_pb(Δ)
26+
δ = project_x(Δ)
27+
units = (y, z...)
28+
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
29+
end
30+
return Ω, times_pb
31+
end
32+
33+
rrule(::typeof(/), x::Number, y::Units) = rrule(*, x, inv(y))
34+
rrule(::typeof(/), x::Units, y::Number) = rrule(*, x, inv(y))

0 commit comments

Comments
 (0)