Skip to content

Commit f71464c

Browse files
committed
Add test with rotations
1 parent 49bb65e commit f71464c

2 files changed

Lines changed: 19 additions & 0 deletions

File tree

test/JuliaTest.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ end
170170
@test jgrad(Â, B̂, x̂) ForwardDiff.gradient(x -> f(Â, B̂, x), x̂)
171171
end
172172

173+
@testset "gradient of x' * sin.(A * x)" begin
174+
f(A, x) = x' * sin.(A * x)
175+
176+
jgrad = eval(to_std(gradient(f(A, x), x); format = dc.JuliaFunc()))
177+
178+
@test jgrad(Â, x̂) ForwardDiff.gradient(x -> f(Â, x), x̂)
179+
end
180+
173181
@testset "jacobian of sin(A * x + y - z)" begin
174182
f(A, x, y, z) = sin.(A * x + y - z)
175183

@@ -201,4 +209,12 @@ end
201209

202210
@test jhess(x̂) ForwardDiff.hessian(x -> f(x), x̂)
203211
end
212+
213+
@testset "hessian of x' * sin.(A * x)" begin
214+
f(A, x) = x' * sin.(A * x)
215+
216+
jhess = eval(to_std(hessian(f(A, x), x); format = dc.JuliaFunc()))
217+
218+
@test jhess(Â, x̂) ForwardDiff.hessian(x -> f(Â, x), x̂)
219+
end
204220
end

test/StdStrTest.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ using LinearAlgebra: tr, diag, diagm, norm
5555
@test to_std(gradient((x .^ 2 .* y)' * c, x)) == "2(x ⊙ y ⊙ c)"
5656
@test to_std(gradient(norm(A * x, 2), x)) == "1/2sum((xᵀAᵀ)²)⁻¹⸍²2AᵀAx"
5757
@test to_std(gradient(log.(A*x)' * x, x)) == "Aᵀdiagm(vec(1) ⊘ (Ax))x + log(Ax)" # TODO: Simplify diagm(quotient) * vec
58+
@test to_std(gradient(x' * sin.(A * x), x)) == "Aᵀdiagm(cos(Ax))x + sin(Ax)"
5859
end
5960

6061
@testset "test Jacobian in standard notation" begin
@@ -91,4 +92,6 @@ end
9192
@test to_std(hessian(2 * x' * x, x)) == "4I"
9293
@test to_std(hessian(sin(cos(x' * A * B' * x)), x)) ==
9394
"cos(cos(xᵀBAᵀx))((-1)sin(xᵀBAᵀx)(BAᵀ + ABᵀ) + (BAᵀx + ABᵀx)(-1)cos(xᵀBAᵀx)(xᵀABᵀ + xᵀBAᵀ)) + (-1)(-1)sin(xᵀBAᵀx)(BAᵀx + ABᵀx)sin(cos(xᵀBAᵀx))(-1)sin(xᵀBAᵀx)(xᵀABᵀ + xᵀBAᵀ)"
95+
@test to_std(hessian(x' * sin.(A * x), x)) ==
96+
"diagm(cos(Ax))A + Aᵀdiagm(cos(Ax)) + (-1)Aᵀdiagm(x ⊙ sin(Ax))A"
9497
end

0 commit comments

Comments
 (0)