Skip to content

Commit a7991dd

Browse files
authored
Add derivative rule for LinearAlgebra.givensAlgorithm (#783)
1 parent c2ec27f commit a7991dd

File tree

5 files changed

+108
-1
lines changed

5 files changed

+108
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "1.2.2"
3+
version = "1.3.0"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"

src/dual.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,57 @@ end
735735
return (Dual{T}(sd, cd * π * partials(d)), Dual{T}(cd, -sd * π * partials(d)))
736736
end
737737

738+
# LinearAlgebra.givensAlgorithm #
739+
#-------------------------------#
740+
741+
# This definition ensures that we match `LinearAlgebra.givensAlgorithm`
742+
# for non-dual numbers (i.e., `ForwardDiff.Dual` with zero partials)
743+
# `LinearAlgebra.givensAlgorithm` is derived from LAPACK's dlartg
744+
# which is [documented](https://netlib.org/lapack/explore-html/da/dd3/group__lartg_ga86f8f877eaea0386cdc2c3c175d9ea88.html) to return
745+
# three values c, s, u for two arguments x and y with
746+
# u = sgn(x) sqrt(x^2 + y^2)
747+
# c = x/u
748+
# s = y/u
749+
# The function is discontinuous in u at x=0
750+
@define_binary_dual_op(
751+
LinearAlgebra.givensAlgorithm,
752+
begin
753+
vx, vy = value(x), value(y)
754+
c, s, u = LinearAlgebra.givensAlgorithm(vx, vy)
755+
∂c∂x = s^2 / u
756+
∂c∂y = ∂s∂x = -(c * s / u)
757+
∂s∂y = c^2 / u
758+
∂x = partials(x)
759+
∂y = partials(y)
760+
∂c = _mul_partials(∂x, ∂y, ∂c∂x, ∂c∂y)
761+
∂s = _mul_partials(∂x, ∂y, ∂s∂x, ∂s∂y)
762+
∂u = _mul_partials(∂x, ∂y, c, s)
763+
return Dual{Txy}(c, ∂c), Dual{Txy}(s, ∂s), Dual{Txy}(u, ∂u)
764+
end,
765+
begin
766+
vx = value(x)
767+
c, s, u = LinearAlgebra.givensAlgorithm(vx, y)
768+
∂c∂x = s^2 / u
769+
∂s∂x = -(c * s / u)
770+
∂x = partials(x)
771+
∂c = ∂c∂x * ∂x
772+
∂s = ∂s∂x * ∂x
773+
∂u = c * ∂x
774+
return Dual{Tx}(c, ∂c), Dual{Tx}(s, ∂s), Dual{Tx}(u, ∂u)
775+
end,
776+
begin
777+
vy = value(y)
778+
c, s, u = LinearAlgebra.givensAlgorithm(x, vy)
779+
∂c∂y = -(c * s / u)
780+
∂s∂y = c^2 / u
781+
∂y = partials(y)
782+
∂c = ∂c∂y * ∂y
783+
∂s = ∂s∂y * ∂y
784+
∂u = s * ∂y
785+
return Dual{Ty}(c, ∂c), Dual{Ty}(s, ∂s), Dual{Ty}(u, ∂u)
786+
end,
787+
)
788+
738789
# Symmetric eigvals #
739790
#-------------------#
740791

test/DerivativeTest.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DerivativeTest
22

33
import Calculus
4+
import LinearAlgebra
45
import NaNMath
56

67
using Test
@@ -122,4 +123,14 @@ end
122123
end
123124
end
124125

126+
@testset "Givens rotations: Derivatives" begin
127+
# Test different branches in `LinearAlgebra.givensAlgorithm`
128+
for f in [randexp(), -randexp()], g in [0.0, f / 2, 2f, -f / 2, -2f], i in 1:3
129+
@test ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(x, g)[i], f)
130+
Calculus.derivative(x -> LinearAlgebra.givensAlgorithm(x, g)[i], f)
131+
@test ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(f, x)[i], g)
132+
Calculus.derivative(x -> LinearAlgebra.givensAlgorithm(f, x)[i], g)
133+
end
134+
end
135+
125136
end # module

test/DualTest.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using NaNMath, SpecialFunctions, LogExpFunctions
1010
using DiffRules
1111

1212
import Calculus
13+
import LinearAlgebra
1314

1415
struct TestTag end
1516
struct OuterTestTag end
@@ -685,4 +686,31 @@ end
685686
@test ForwardDiff.derivative(x -> sum(1 .+ x .* (0:0.1:1)), 1) == 5.5
686687
end
687688

689+
@testset "Givens rotations: consistency with `LinearAlgebra.givensAlgorithm` for zero partials (no duals)" begin
690+
# Test different branches in `LinearAlgebra.givensAlgorithm`
691+
for f in [randexp(), -randexp()], g in [0.0, f / 2, 2f, -f / 2, -2f]
692+
# Upstream: Result for non-dual numbers
693+
y = LinearAlgebra.givensAlgorithm(f, g)
694+
@test y isa NTuple{3,Float64}
695+
696+
for n in (1, 2, 5)
697+
zero_tuple = ntuple(Returns(0.0), n)
698+
dual_f = Dual{TestTag}(f, zero_tuple)
699+
dual_g = Dual{TestTag}(g, zero_tuple)
700+
for (_f, _g) in ((dual_f, dual_g), (dual_f, g), (f, dual_g))
701+
ydual = @inferred(LinearAlgebra.givensAlgorithm(_f, _g))
702+
@test ydual isa NTuple{3,Dual{TestTag,Float64,n}}
703+
704+
for (i, yi, yduali) in zip(1:3, y, ydual)
705+
# Primal values must match `LinearAlgebra.givensAlgorithm` with `Float64` inputs
706+
@test ForwardDiff.value(yduali) yi
707+
708+
# Partial derivatives must be zero (zero in - zero out)
709+
@test iszero(ForwardDiff.partials(yduali))
710+
end
711+
end
712+
end
713+
end
714+
end
715+
688716
end # module

test/GradientTest.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module GradientTest
22

33
import Calculus
4+
import LinearAlgebra
45
import NaNMath
56

67
using Test
@@ -330,4 +331,20 @@ end
330331
end
331332
end
332333

334+
@testset "Givens rotations: Gradients" begin
335+
# Test different branches in `LinearAlgebra.givensAlgorithm`
336+
for f in [randexp(), -randexp()], g in [0.0, f / 2, 2f, -f / 2, -2f], i in 1:3
337+
# Gradients wrt to a single input argument
338+
dydf = only(ForwardDiff.gradient(x -> LinearAlgebra.givensAlgorithm(only(x), g)[i], [f]))
339+
@test dydf == ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(x, g)[i], f)
340+
dydg = only(ForwardDiff.gradient(x -> LinearAlgebra.givensAlgorithm(f, only(x))[i], [g]))
341+
@test dydg == ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(f, x)[i], g)
342+
343+
# Gradient with respect to both input arguments
344+
grad = ForwardDiff.gradient(x -> LinearAlgebra.givensAlgorithm(x[1], x[2])[i], [f, g])
345+
@test grad == [dydf, dydg]
346+
@test grad Calculus.gradient(x -> LinearAlgebra.givensAlgorithm(x[1], x[2])[i], [f, g])
347+
end
348+
end
349+
333350
end # module

0 commit comments

Comments
 (0)