Skip to content

Commit 27de6e5

Browse files
authored
✨ Add Zygote extension
1 parent c32791f commit 27de6e5

File tree

7 files changed

+128
-21
lines changed

7 files changed

+128
-21
lines changed

Project.toml

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,32 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414

15+
[weakdeps]
16+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
17+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
18+
19+
[extensions]
20+
ReferenceFrameRotationsZygoteExt = ["ForwardDiff", "Zygote"]
21+
1522
[compat]
1623
Crayons = "4.0"
24+
DifferentiationInterface = "0.6"
25+
FiniteDiff = "2.26"
26+
ForwardDiff = "0.10"
1727
LinearAlgebra = "1.6"
28+
Mooncake = "0.4"
1829
Printf = "1.6"
1930
Random = "1.6"
2031
StaticArrays = "1"
2132
Test = "1.6"
22-
julia = "1.6"
33+
Zygote = "0.6"
34+
julia = "1.10, 1.11"
2335

2436
[extras]
37+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
38+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
39+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2540
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2641

2742
[targets]
28-
test = ["Test"]
43+
test = ["Test", "DifferentiationInterface", "FiniteDiff", "Mooncake", "Zygote"]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module ReferenceFrameRotationsZygoteExt
2+
3+
using ReferenceFrameRotations
4+
using ForwardDiff
5+
6+
using Zygote.ChainRulesCore: ChainRulesCore
7+
import Zygote.ChainRulesCore: NoTangent
8+
9+
function ChainRulesCore.rrule(
10+
::Type{<:DCM}, data::NTuple{9, T}
11+
) where {T}
12+
13+
y = DCM(data)
14+
15+
function DCM_pullback(Δ)
16+
return (NoTangent(), Tuple(Δ))
17+
end
18+
19+
return y, DCM_pullback
20+
end
21+
22+
function ChainRulesCore.rrule(
23+
::typeof(orthonormalize), dcm::DCM
24+
)
25+
26+
y = orthonormalize(dcm)
27+
28+
function orthonormalize_pullback(Δ)
29+
30+
jac = ForwardDiff.jacobian(orthonormalize, dcm)
31+
32+
return (NoTangent(), reshape(vcat...)' * jac, 3, 3))
33+
end
34+
35+
return y, orthonormalize_pullback
36+
end
37+
38+
end

src/conversions/angle_to_dcm.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,26 @@ DCM{Float64}:
5757
0.707107 0.612372 0.353553
5858
```
5959
"""
60-
function angle_to_dcm::Number, rot_seq::Symbol)
60+
function angle_to_dcm::T, rot_seq::Symbol) where {T<:Number}
6161
sa, ca = sincos(θ)
6262

6363
if rot_seq == :X
6464
return DCM(
65-
1, 0, 0,
66-
0, +ca, +sa,
67-
0, -sa, +ca
65+
T(1), T(0), T(0),
66+
T(0), +ca, +sa,
67+
T(0), -sa, +ca
6868
)'
6969
elseif rot_seq == :Y
7070
return DCM(
71-
+ca, 0, -sa,
72-
0, 1, 0,
73-
+sa, 0, +ca
71+
+ca, T(0), -sa,
72+
T(0), T(1), T(0),
73+
+sa, T(0), +ca
7474
)'
7575
elseif rot_seq == :Z
7676
return DCM(
77-
+ca, +sa, 0,
78-
-sa, +ca, 0,
79-
0, 0, 1
77+
+ca, +sa, T(0),
78+
-sa, +ca, T(0),
79+
T(0), T(0), T(1)
8080
)'
8181
else
8282
throw(ArgumentError("rot_seq must be :X, :Y, or :Z"))

src/conversions/smallangle_to_dcm.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,20 @@ DCM{Float64}:
3535
θx::T1,
3636
θy::T2,
3737
θz::T3;
38-
normalize = true
38+
normalize::Bool = true
3939
) where {T1<:Number, T2<:Number, T3<:Number}
4040
# Since we might orthonormalize `D`, we need to get the float to avoid type
4141
# instabilities.
42-
T = float(promote_type(T1, T2, T3))
42+
T = promote_type(T1, T2, T3)
4343

44-
D = DCM{T}(
45-
1, +θz, -θy,
46-
-θz, 1, +θx,
47-
+θy, -θx, 1
44+
θx = T(θx)
45+
θy = T(θy)
46+
θz = T(θz)
47+
48+
D = DCM(
49+
T(1), +θz, -θy,
50+
-θz, T(1), +θx,
51+
+θy, -θx, T(1)
4852
)'
4953

5054
if normalize
@@ -53,3 +57,12 @@ DCM{Float64}:
5357
return D
5458
end
5559
end
60+
61+
@inline function smallangle_to_dcm(
62+
θx::Integer,
63+
θy::Integer,
64+
θz::Integer;
65+
normalize::Bool = true
66+
)
67+
return smallangle_to_dcm(float(θx), float(θy), float(θz); normalize = normalize)
68+
end

src/dcm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ julia> orthonormalize(D)
5858
```
5959
"""
6060
function orthonormalize(dcm::DCM)
61-
e₁ = dcm[:, 1]
62-
e₂ = dcm[:, 2]
63-
e₃ = dcm[:, 3]
61+
e₁ = SVector{3}(dcm[:, 1])
62+
e₂ = SVector{3}(dcm[:, 2])
63+
e₃ = SVector{3}(dcm[:, 3])
6464

6565
en₁ = e₁ / norm(e₁)
6666
enj₂ = e₂ - (en₁ e₂) * en₁

test/differentiability/dcm.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
## Description #############################################################################
2+
#
3+
# Test Functions for DCM Zygote Extension
4+
#
5+
############################################################################################
6+
7+
8+
@testset "Test DCM Zygote Differentiation" begin
9+
10+
data = [0.9071183, -0.38511035, 0.1697833, -0.18077055, 0.0077917147, 0.98349446, -0.38007677, -0.9228377, -0.06254859]
11+
12+
f, ad = value_and_jacobian(DCM, AutoZygote(), data)
13+
14+
expected_f = DCM(data)
15+
expected_jac = I(9)
16+
17+
@test f == expected_f
18+
@test ad == expected_jac
19+
20+
data_tuple = (data...,)
21+
22+
ad_jac = reduce(hcat, Zygote.jacobian(DCM, data_tuple...))
23+
24+
@test ad_jac == expected_jac
25+
26+
f_fd, df_fd = value_and_jacobian((x) -> orthonormalize(DCM(x)), AutoFiniteDiff(), data)
27+
f_ad, df_ad = value_and_jacobian((x) -> orthonormalize(DCM(x)), AutoZygote(), data)
28+
29+
f_adm, df_adm = value_and_jacobian((x) -> Array(orthonormalize(DCM(x))), AutoMooncake(;config=nothing), data)
30+
31+
@test f_ad == f_fd
32+
@test df_ad df_fd rtol=1e-7
33+
34+
end

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using StaticArrays
66

77
import Base: isapprox
88

9+
using DifferentiationInterface, FiniteDiff, Mooncake, Zygote
10+
911
############################################################################################
1012
# Auxiliary Functions #
1113
############################################################################################
@@ -134,3 +136,8 @@ println("")
134136
include("./random.jl")
135137
end
136138
println("")
139+
140+
@time @testset "Test DCM Differentiation" begin
141+
include("differentiability/dcm.jl")
142+
end
143+
println("")

0 commit comments

Comments
 (0)