Skip to content

Commit f9f6fbf

Browse files
authored
just test finiteDiff vs Zygote
1 parent 1909958 commit f9f6fbf

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

test/adjoint.jl

+11-5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
4444
@test dA dA2
4545
@test db1 db12
4646

47+
# Test complex numbers
48+
A = rand(n, n) + 1im*rand(n, n);
49+
b1 = rand(n) + 1im*rand(n, n);
50+
4751
function f3(A, b1, b2; alg = KrylovJL_GMRES())
4852
prob = LinearProblem(A, b1)
4953
sol1 = solve(prob, alg)
@@ -66,6 +70,9 @@ db22 = FiniteDiff.finite_difference_gradient(
6670
@test db1 db12
6771
@test db2 db22
6872

73+
A = rand(n, n);
74+
b1 = rand(n);
75+
6976
function f4(A, b1, b2; alg = LUFactorization())
7077
prob = LinearProblem(A, b1)
7178
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
@@ -85,9 +92,8 @@ db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))
8592
@test db1 db12
8693
@test db2 db22
8794

88-
# Test complex numbers
89-
A = rand(n, n) + 1im * rand(n, n);
90-
b1 = rand(n) + 1im * rand(n);
95+
A = rand(n, n);
96+
b1 = rand(n);
9197
for alg in (
9298
LUFactorization(),
9399
RFLUFactorization(),
@@ -99,7 +105,7 @@ for alg in (
99105

100106
sol1 = solve(prob, alg)
101107

102-
sum(abs2.(sol1.u))
108+
sum(sol1.u)
103109
end
104110
fb(b1)
105111

@@ -116,7 +122,7 @@ for alg in (
116122

117123
sol1 = solve(prob, alg)
118124

119-
sum(abs2.(sol1.u))
125+
sum(sol1.u)
120126
end
121127
fA(A)
122128

0 commit comments

Comments
 (0)