Skip to content

Feat Zygote Extension #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4bca650
feat: Zygote Extension for DCM constructor
jmurphy6895 Nov 16, 2024
7871319
chore: more zygote adjoints
jmurphy6895 Nov 18, 2024
2bf4a7c
chore: fix Zygote extension
jmurphy6895 Dec 15, 2024
6ff2ed7
chore: fix for orthonormalize function
jmurphy6895 Dec 16, 2024
36cd5dd
chore: additional rrule for DCM zygote
jmurphy6895 Dec 16, 2024
24085db
chore: change Zygote construction
jmurphy6895 Dec 16, 2024
d4ca8e9
chore: fix typing
jmurphy6895 Dec 16, 2024
df0f104
chore: more ext changes
jmurphy6895 Dec 16, 2024
30697f6
chore: another attempt at zygote
jmurphy6895 Dec 16, 2024
5bf8cb4
chore: another attempt
jmurphy6895 Dec 16, 2024
6fedb2e
chore: try loosening type
jmurphy6895 Dec 16, 2024
47aa9f8
chore: another attempt
jmurphy6895 Dec 16, 2024
e8899f9
chore: return type change
jmurphy6895 Dec 16, 2024
6f16814
chore: separate DCM matrix constuctors
jmurphy6895 Dec 16, 2024
8358198
chore: fix definition
jmurphy6895 Dec 16, 2024
70b530b
chore: try to fix Matrix Error
jmurphy6895 Dec 16, 2024
29e47b6
chore: try simplifying return
jmurphy6895 Dec 16, 2024
528459e
chore: new attempt
jmurphy6895 Dec 16, 2024
a0180ef
chore: new attempt
jmurphy6895 Dec 16, 2024
87507f5
chore: try simplifying Zygote ext
jmurphy6895 Dec 16, 2024
22f581d
chore: new approach, pass constuctor a tuple
jmurphy6895 Dec 16, 2024
5336d4d
chore: try expanding Tuple case
jmurphy6895 Dec 16, 2024
3a63a0f
chore: fix typing
jmurphy6895 Dec 16, 2024
3af9e74
chore: promotion changes
jmurphy6895 Dec 17, 2024
96e9577
fix: Zygote & Mooncake compatibility in orthonormalize
jmurphy6895 Dec 31, 2024
419fefd
fix: add second zygote constructor
jmurphy6895 Dec 31, 2024
16f2174
fix: small changes to zygote constructor
jmurphy6895 Dec 31, 2024
37f3539
fix: remove extra constuctor, small typing change in angle_to_dcm
jmurphy6895 Dec 31, 2024
5c30617
chore: fix typing bug
jmurphy6895 Jan 5, 2025
52e163f
chore: replace orthonormalize zygote headache with forwarddiff
jmurphy6895 Jan 8, 2025
91d221c
chore: fix gradient size
jmurphy6895 Jan 8, 2025
a477368
chore: small code clean up
jmurphy6895 Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,32 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ReferenceFrameRotationsZygoteExt = ["ForwardDiff", "Zygote"]

[compat]
Crayons = "4.0"
DifferentiationInterface = "0.6"
FiniteDiff = "2.26"
ForwardDiff = "0.10"
LinearAlgebra = "1.6"
Mooncake = "0.4"
Printf = "1.6"
Random = "1.6"
StaticArrays = "1"
Test = "1.6"
julia = "1.6"
Zygote = "0.6"
julia = "1.10, 1.11"

[extras]
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "DifferentiationInterface", "FiniteDiff", "Mooncake", "Zygote"]
38 changes: 38 additions & 0 deletions ext/ReferenceFrameRotationsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module ReferenceFrameRotationsZygoteExt

using ReferenceFrameRotations
using ForwardDiff

using Zygote.ChainRulesCore: ChainRulesCore
import Zygote.ChainRulesCore: NoTangent

function ChainRulesCore.rrule(
::Type{<:DCM}, data::NTuple{9, T}
) where {T}

y = DCM(data)

function DCM_pullback(Δ)
return (NoTangent(), Tuple(Δ))
end

return y, DCM_pullback
end

function ChainRulesCore.rrule(
::typeof(orthonormalize), dcm::DCM
)

y = orthonormalize(dcm)

function orthonormalize_pullback(Δ)

jac = ForwardDiff.jacobian(orthonormalize, dcm)

return (NoTangent(), reshape(vcat(Δ...)' * jac, 3, 3))
end

return y, orthonormalize_pullback
end

end
20 changes: 10 additions & 10 deletions src/conversions/angle_to_dcm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,26 @@ DCM{Float64}:
0.707107 0.612372 0.353553
```
"""
function angle_to_dcm(θ::Number, rot_seq::Symbol)
function angle_to_dcm(θ::T, rot_seq::Symbol) where {T<:Number}
sa, ca = sincos(θ)

if rot_seq == :X
return DCM(
1, 0, 0,
0, +ca, +sa,
0, -sa, +ca
T(1), T(0), T(0),
T(0), +ca, +sa,
T(0), -sa, +ca
)'
elseif rot_seq == :Y
return DCM(
+ca, 0, -sa,
0, 1, 0,
+sa, 0, +ca
+ca, T(0), -sa,
T(0), T(1), T(0),
+sa, T(0), +ca
)'
elseif rot_seq == :Z
return DCM(
+ca, +sa, 0,
-sa, +ca, 0,
0, 0, 1
+ca, +sa, T(0),
-sa, +ca, T(0),
T(0), T(0), T(1)
)'
else
throw(ArgumentError("rot_seq must be :X, :Y, or :Z"))
Expand Down
25 changes: 19 additions & 6 deletions src/conversions/smallangle_to_dcm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,20 @@ DCM{Float64}:
θx::T1,
θy::T2,
θz::T3;
normalize = true
normalize::Bool = true
) where {T1<:Number, T2<:Number, T3<:Number}
# Since we might orthonormalize `D`, we need to get the float to avoid type
# instabilities.
T = float(promote_type(T1, T2, T3))
T = promote_type(T1, T2, T3)

D = DCM{T}(
1, +θz, -θy,
-θz, 1, +θx,
+θy, -θx, 1
θx = T(θx)
θy = T(θy)
θz = T(θz)

D = DCM(
T(1), +θz, -θy,
-θz, T(1), +θx,
+θy, -θx, T(1)
)'

if normalize
Expand All @@ -53,3 +57,12 @@ DCM{Float64}:
return D
end
end

@inline function smallangle_to_dcm(
θx::Integer,
θy::Integer,
θz::Integer;
normalize::Bool = true
)
return smallangle_to_dcm(float(θx), float(θy), float(θz); normalize = normalize)
end
6 changes: 3 additions & 3 deletions src/dcm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ julia> orthonormalize(D)
```
"""
function orthonormalize(dcm::DCM)
e₁ = dcm[:, 1]
e₂ = dcm[:, 2]
e₃ = dcm[:, 3]
e₁ = SVector{3}(dcm[:, 1])
e₂ = SVector{3}(dcm[:, 2])
e₃ = SVector{3}(dcm[:, 3])

en₁ = e₁ / norm(e₁)
enj₂ = e₂ - (en₁ ⋅ e₂) * en₁
Expand Down
34 changes: 34 additions & 0 deletions test/differentiability/dcm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
## Description #############################################################################
#
# Test Functions for DCM Zygote Extension
#
############################################################################################


@testset "Test DCM Zygote Differentiation" begin

data = [0.9071183, -0.38511035, 0.1697833, -0.18077055, 0.0077917147, 0.98349446, -0.38007677, -0.9228377, -0.06254859]

f, ad = value_and_jacobian(DCM, AutoZygote(), data)

expected_f = DCM(data)
expected_jac = I(9)

@test f == expected_f
@test ad == expected_jac

data_tuple = (data...,)

ad_jac = reduce(hcat, Zygote.jacobian(DCM, data_tuple...))

@test ad_jac == expected_jac

f_fd, df_fd = value_and_jacobian((x) -> orthonormalize(DCM(x)), AutoFiniteDiff(), data)
f_ad, df_ad = value_and_jacobian((x) -> orthonormalize(DCM(x)), AutoZygote(), data)

f_adm, df_adm = value_and_jacobian((x) -> Array(orthonormalize(DCM(x))), AutoMooncake(;config=nothing), data)

@test f_ad == f_fd
@test df_ad ≈ df_fd rtol=1e-7

end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using StaticArrays

import Base: isapprox

using DifferentiationInterface, FiniteDiff, Mooncake, Zygote

############################################################################################
# Auxiliary Functions #
############################################################################################
Expand Down Expand Up @@ -134,3 +136,8 @@ println("")
include("./random.jl")
end
println("")

@time @testset "Test DCM Differentiation" begin
include("differentiability/dcm.jl")
end
println("")