Skip to content

Commit fc20e38

Browse files
committed
fix
1 parent 29c4dce commit fc20e38

File tree

6 files changed

+80
-11
lines changed

6 files changed

+80
-11
lines changed

lib/EnzymeTestUtils/src/EnzymeTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include("test_approx.jl")
1515
include("compatible_activities.jl")
1616
include("finite_difference_calls.jl")
1717
include("generate_tangent.jl")
18+
include("test_utils.jl")
1819
include("test_forward.jl")
1920
include("test_reverse.jl")
2021

lib/EnzymeTestUtils/src/test_forward.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ function test_forward(
6363
testset_name=nothing,
6464
runtime_activity::Bool=false
6565
)
66-
call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)
67-
call_with_kwargs(f, xs...) = f(xs...; fkwargs...)
66+
call_with_copy = CallWithCopyKWargs(fkwargs)
67+
call_with_kwargs = CallWithKWargs(fkwargs)
6868
if testset_name === nothing
6969
testset_name = "test_forward: $f with return activity $ret_activity on $(_string_activity(args))"
7070
end
7171
@testset "$testset_name" begin
7272
# format arguments for autodiff and FiniteDifferences
7373
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
74-
primals = map(x -> x.val, activities)
74+
primals = map(get_primal, activities)
7575
# call primal, avoid mutating original arguments
7676
fcopy = deepcopy(first(primals))
7777
args_copy = deepcopy(Base.tail(primals))
@@ -149,7 +149,7 @@ function test_forward(
149149
end
150150
else
151151
test_approx(
152-
dy_ad, dy_fdm, "derivative should agree with finite differences"; atol, rtol
152+
dy_ad, dy_fdm, "derivative should agree with finite differences ($activities)"; atol, rtol
153153
)
154154
end
155155
end

lib/EnzymeTestUtils/src/test_reverse.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ end
5858
5959
Here we test a rule for a function of an array in batch reverse-mode:
6060
61-
62-
6361
```julia
6462
x = randn(3)
6563
for Tret in (Const, Active), Tx in (Const, BatchDuplicated)
@@ -79,14 +77,14 @@ function test_reverse(
7977
testset_name=nothing,
8078
runtime_activity::Bool=false,
8179
output_tangent=nothing)
82-
call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...)
80+
call_with_captured_kwargs = CallWithKWargs(fkwargs)
8381
if testset_name === nothing
8482
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
8583
end
8684
@testset "$testset_name" begin
8785
# format arguments for autodiff and FiniteDifferences
8886
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
89-
primals = map(x -> x.val, activities)
87+
primals = map(get_primal, activities)
9088
# call primal, avoid mutating original arguments
9189
fcopy = deepcopy(first(primals))
9290
args_copy = deepcopy(Base.tail(primals))
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
struct CallWithKWargs{KW}
3+
kwargs::KW
4+
end
5+
6+
function (c::CallWithKWargs)(f, xs...)
7+
f(xs...; c.kwargs...)
8+
end
9+
10+
struct CallWithCopyKWargs{KW}
11+
kwargs::KW
12+
end
13+
14+
function (c::CallWithCopyKWargs)(f, xs...)
15+
deepcopy(f)(deepcopy(xs)...; deepcopy(c.kwargs)...)
16+
end
17+
18+
@inline function get_primal(x::Annotation)
19+
x.val
20+
end

src/rules/customrules.jl

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ import LinearAlgebra
4848
:(LinearAlgebra.dot(partial,dx))
4949
elseif dx <: AbstractFloat || dx <: AbstractArray{<:AbstractFloat}
5050
:(LinearAlgebra.dot(dx, partial))
51-
else
51+
elseif partial <: AbstractVector
5252
:(LinearAlgebra.dot(adjoint(partial),dx))
53+
else
54+
:(LinearAlgebra.dot(conj(partial),dx))
5355
end
5456
return quote
5557
Base.@_inline_meta
@@ -106,10 +108,52 @@ import LinearAlgebra
106108
end
107109
end
108110

109-
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial, dx)
111+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Real, dx)
112+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
113+
end
114+
115+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Complex, dx)
116+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
117+
end
118+
119+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::Number)
120+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
121+
end
122+
123+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::Number)
124+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
125+
end
126+
127+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real, N}, dx::AbstractArray{<:Any, N}) where N
128+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
129+
end
130+
131+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex, N}, dx::AbstractArray{<:Any, N}) where N
132+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
133+
end
134+
135+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractVector{<:Complex}, dx::AbstractVector{<:Any})
110136
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
111137
end
112138

139+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Real}, dx::AbstractVector)
140+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, transpose(partial), dx)
141+
end
142+
143+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Complex}, dx::AbstractVector)
144+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
145+
end
146+
147+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::AbstractArray)
148+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...)), dx)
149+
end
150+
151+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::AbstractArray)
152+
pd = Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...))
153+
Base.conj!(pd)
154+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, pd, dx)
155+
end
156+
113157
function enzyme_custom_setup_args(
114158
@nospecialize(B::Union{Nothing, LLVM.IRBuilder}),
115159
orig::LLVM.CallInst,

test/rules/internal_rules.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,13 @@ end
220220
end
221221
end
222222
@testset "reverse" begin
223-
@testset for RT in (Const, Active,), Tx in (Const, Duplicated,),
223+
@testset for RT in (Const, Active,), Tx in (Const, Duplicated,)
224+
225+
# TODO see https://github.com/EnzymeAD/Enzyme/issues/2537
226+
if RT <: Const
227+
continue
228+
end
229+
224230
x = [4.0 3.0; 2.0 1.0]
225231
test_reverse(LinearAlgebra.det, RT, (x, Tx))
226232

0 commit comments

Comments
 (0)