diff --git a/Project.toml b/Project.toml index d6669729..7b769f1d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,16 @@ name = "Unitful" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.11.0" +version = "1.12.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] +ChainRulesCore = "1" ConstructionBase = "1" julia = "1" diff --git a/src/Unitful.jl b/src/Unitful.jl index dac82934..bf642c76 100644 --- a/src/Unitful.jl +++ b/src/Unitful.jl @@ -27,6 +27,8 @@ import Random import ConstructionBase: constructorof +import ChainRulesCore: rrule, NoTangent, ProjectTo + export logunit, unit, absoluteunit, dimension, uconvert, ustrip, upreferred export @dimension, @derived_dimension, @refunit, @unit, @affineunit, @u_str export Quantity, DimensionlessQuantity, NoUnits, NoDims @@ -69,5 +71,6 @@ include("logarithm.jl") include("complex.jl") include("pkgdefaults.jl") include("dates.jl") +include("chainrules.jl") end diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 00000000..996e5280 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,29 @@ +function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U} + unitful_x = Quantity{T,D,U}(x) + projector_x = ProjectTo(x) + uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT)) + return unitful_x, uq_pullback +end + +function (projector::ProjectTo{<:Quantity})(x::Number) + new_val = projector.project_val(ustrip(x)) + return new_val*unit(x) +end + +# Project Unitful Quantities onto numerical types by projecting the value and carrying units +ProjectTo(x::Quantity) = ProjectTo(x.val) + +(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx) +(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx) + +function rrule(::typeof(*), x::Quantity, y::Units, z::Units...) + Ω = *(x, y, z...) + function times_pb(Δ) + nots = ntuple(_ -> NoTangent(), 1 + length(z)) + return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...) + end + return Ω, times_pb +end + +rrule(::typeof(/), x::Number, y::Units) = rrule(*, x, inv(y)) +rrule(::typeof(/), x::Units, y::Number) = rrule(*, x, inv(y)) diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 00000000..03e39ddd --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,52 @@ +using ChainRulesCore: rrule, ProjectTo, NoTangent + +@testset "ProjectTo" begin + real_test(proj, val) = proj(val) == real(val) + complex_test(proj, val) = proj(val) == val + uval = 8.0*u"W" + p_uval = ProjectTo(uval) + cuval = (1.0+im)*u"kg" + p_cuval = ProjectTo(cuval) + + p_real = ProjectTo(1.0) + p_complex = ProjectTo(1.0+im) + + δval = 6.0*u"m" + δcval = (2.0+3.0im)*u"L" + + # Test projection onto real unitful quantities + for δ in (δval, δcval, 1.0, 1.0+im) + @test real_test(p_uval, δ) + end + + # Test projection onto complex unitful quantities + for δ in (δval, δcval, 1.0, 1.0+im) + @test complex_test(p_cuval, δ) + end + + # Projecting Unitful quantities onto real values + @test p_real(δval) == δval + @test p_real(δcval) == real(δcval) + + # Projecting Unitful quantities onto complex values + @test p_complex(δval) == δval + @test p_complex(δcval) == δcval +end + +@testset "rrules" begin + @testset "Quantity rrule" begin + UT = typeof(1.0*u"W") + x = 5.0 + Ω, pb = rrule(UT, x) + @test Ω == 5.0 * u"W" + @test pb(3.0) == (NoTangent(), 3.0 * u"W") + end + @testset "* rrule" begin + x = 5.0*u"W" + y = u"m" + z = u"L" + Ω, pb = rrule(*, x, y, z) + @test Ω == x*y*z + @test pb(3.0) == (NoTangent(), 3.0*y*z, NoTangent(), NoTangent()) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b6481012..53803642 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2047,6 +2047,10 @@ end """ end +@testset "ChainRules" begin + include("./chainrules.jl") +end + # Test precompiled Unitful extension modules load_path = mktempdir() load_cache_path = mktempdir()