diff --git a/Project.toml b/Project.toml index a6a0657..c12d51b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,17 +12,32 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +ReferenceFrameRotationsZygoteExt = ["ForwardDiff", "Zygote"] + [compat] Crayons = "4.0" +DifferentiationInterface = "0.6" +FiniteDiff = "2.26" +ForwardDiff = "0.10" LinearAlgebra = "1.6" +Mooncake = "0.4" Printf = "1.6" Random = "1.6" StaticArrays = "1" Test = "1.6" -julia = "1.6" +Zygote = "0.6" +julia = "1.10, 1.11" [extras] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "DifferentiationInterface", "FiniteDiff", "Mooncake", "Zygote"] diff --git a/ext/ReferenceFrameRotationsZygoteExt.jl b/ext/ReferenceFrameRotationsZygoteExt.jl new file mode 100644 index 0000000..3066e5a --- /dev/null +++ b/ext/ReferenceFrameRotationsZygoteExt.jl @@ -0,0 +1,38 @@ +module ReferenceFrameRotationsZygoteExt + +using ReferenceFrameRotations +using ForwardDiff + +using Zygote.ChainRulesCore: ChainRulesCore +import Zygote.ChainRulesCore: NoTangent + +function ChainRulesCore.rrule( + ::Type{<:DCM}, data::NTuple{9, T} +) where {T} + + y = DCM(data) + + function DCM_pullback(Δ) + return (NoTangent(), Tuple(Δ)) + end + + return y, DCM_pullback +end + +function ChainRulesCore.rrule( + ::typeof(orthonormalize), dcm::DCM +) + + y = orthonormalize(dcm) + + function orthonormalize_pullback(Δ) + + jac = ForwardDiff.jacobian(orthonormalize, dcm) + + return (NoTangent(), reshape(vcat(Δ...)' * jac, 3, 3)) + end + + return y, orthonormalize_pullback +end + +end \ No newline at end of file diff --git a/src/conversions/angle_to_dcm.jl b/src/conversions/angle_to_dcm.jl index e8b2296..ca0e77a 100644 --- a/src/conversions/angle_to_dcm.jl +++ b/src/conversions/angle_to_dcm.jl @@ -57,26 +57,26 @@ DCM{Float64}: 0.707107 0.612372 0.353553 ``` """ -function angle_to_dcm(θ::Number, rot_seq::Symbol) +function angle_to_dcm(θ::T, rot_seq::Symbol) where {T<:Number} sa, ca = sincos(θ) if rot_seq == :X return DCM( - 1, 0, 0, - 0, +ca, +sa, - 0, -sa, +ca + T(1), T(0), T(0), + T(0), +ca, +sa, + T(0), -sa, +ca )' elseif rot_seq == :Y return DCM( - +ca, 0, -sa, - 0, 1, 0, - +sa, 0, +ca + +ca, T(0), -sa, + T(0), T(1), T(0), + +sa, T(0), +ca )' elseif rot_seq == :Z return DCM( - +ca, +sa, 0, - -sa, +ca, 0, - 0, 0, 1 + +ca, +sa, T(0), + -sa, +ca, T(0), + T(0), T(0), T(1) )' else throw(ArgumentError("rot_seq must be :X, :Y, or :Z")) diff --git a/src/conversions/smallangle_to_dcm.jl b/src/conversions/smallangle_to_dcm.jl index 2b92bb1..30630f8 100644 --- a/src/conversions/smallangle_to_dcm.jl +++ b/src/conversions/smallangle_to_dcm.jl @@ -35,16 +35,20 @@ DCM{Float64}: θx::T1, θy::T2, θz::T3; - normalize = true + normalize::Bool = true ) where {T1<:Number, T2<:Number, T3<:Number} # Since we might orthonormalize `D`, we need to get the float to avoid type # instabilities. - T = float(promote_type(T1, T2, T3)) + T = promote_type(T1, T2, T3) - D = DCM{T}( - 1, +θz, -θy, - -θz, 1, +θx, - +θy, -θx, 1 + θx = T(θx) + θy = T(θy) + θz = T(θz) + + D = DCM( + T(1), +θz, -θy, + -θz, T(1), +θx, + +θy, -θx, T(1) )' if normalize @@ -53,3 +57,12 @@ DCM{Float64}: return D end end + +@inline function smallangle_to_dcm( + θx::Integer, + θy::Integer, + θz::Integer; + normalize::Bool = true +) + return smallangle_to_dcm(float(θx), float(θy), float(θz); normalize = normalize) +end \ No newline at end of file diff --git a/src/dcm.jl b/src/dcm.jl index f6aaacb..8d7b96a 100644 --- a/src/dcm.jl +++ b/src/dcm.jl @@ -58,9 +58,9 @@ julia> orthonormalize(D) ``` """ function orthonormalize(dcm::DCM) - e₁ = dcm[:, 1] - e₂ = dcm[:, 2] - e₃ = dcm[:, 3] + e₁ = SVector{3}(dcm[:, 1]) + e₂ = SVector{3}(dcm[:, 2]) + e₃ = SVector{3}(dcm[:, 3]) en₁ = e₁ / norm(e₁) enj₂ = e₂ - (en₁ ⋅ e₂) * en₁ diff --git a/test/differentiability/dcm.jl b/test/differentiability/dcm.jl new file mode 100644 index 0000000..41dae66 --- /dev/null +++ b/test/differentiability/dcm.jl @@ -0,0 +1,34 @@ +## Description ############################################################################# +# +# Test Functions for DCM Zygote Extension +# +############################################################################################ + + +@testset "Test DCM Zygote Differentiation" begin + + data = [0.9071183, -0.38511035, 0.1697833, -0.18077055, 0.0077917147, 0.98349446, -0.38007677, -0.9228377, -0.06254859] + + f, ad = value_and_jacobian(DCM, AutoZygote(), data) + + expected_f = DCM(data) + expected_jac = I(9) + + @test f == expected_f + @test ad == expected_jac + + data_tuple = (data...,) + + ad_jac = reduce(hcat, Zygote.jacobian(DCM, data_tuple...)) + + @test ad_jac == expected_jac + + f_fd, df_fd = value_and_jacobian((x) -> orthonormalize(DCM(x)), AutoFiniteDiff(), data) + f_ad, df_ad = value_and_jacobian((x) -> orthonormalize(DCM(x)), AutoZygote(), data) + + f_adm, df_adm = value_and_jacobian((x) -> Array(orthonormalize(DCM(x))), AutoMooncake(;config=nothing), data) + + @test f_ad == f_fd + @test df_ad ≈ df_fd rtol=1e-7 + +end diff --git a/test/runtests.jl b/test/runtests.jl index e21a83c..30728fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,8 @@ using StaticArrays import Base: isapprox +using DifferentiationInterface, FiniteDiff, Mooncake, Zygote + ############################################################################################ # Auxiliary Functions # ############################################################################################ @@ -134,3 +136,8 @@ println("") include("./random.jl") end println("") + +@time @testset "Test DCM Differentiation" begin + include("differentiability/dcm.jl") +end +println("")