Skip to content

Commit dbadb67

Browse files
kshyattwsmoses
andauthored
Add easy_rule for matrix det (#2725)
* Add easy_rule for matrix det * fix * fix * efficiency * more efficiency --------- Co-authored-by: William S. Moses <gh@wsmoses.com>
1 parent 3e48bb3 commit dbadb67

File tree

8 files changed

+122
-15
lines changed

8 files changed

+122
-15
lines changed

lib/EnzymeTestUtils/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeTestUtils"
22
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
33
authors = ["Seth Axen <seth@sethaxen.com>", "William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
4-
version = "0.2.6"
4+
version = "0.2.7"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

lib/EnzymeTestUtils/src/EnzymeTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include("test_approx.jl")
1515
include("compatible_activities.jl")
1616
include("finite_difference_calls.jl")
1717
include("generate_tangent.jl")
18+
include("test_utils.jl")
1819
include("test_forward.jl")
1920
include("test_reverse.jl")
2021

lib/EnzymeTestUtils/src/test_forward.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ function test_forward(
6363
testset_name=nothing,
6464
runtime_activity::Bool=false
6565
)
66-
call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)
67-
call_with_kwargs(f, xs...) = f(xs...; fkwargs...)
66+
call_with_copy = CallWithCopyKWargs(fkwargs)
67+
call_with_kwargs = CallWithKWargs(fkwargs)
6868
if testset_name === nothing
6969
testset_name = "test_forward: $f with return activity $ret_activity on $(_string_activity(args))"
7070
end
7171
@testset "$testset_name" begin
7272
# format arguments for autodiff and FiniteDifferences
7373
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
74-
primals = map(x -> x.val, activities)
74+
primals = map(get_primal, activities)
7575
# call primal, avoid mutating original arguments
7676
fcopy = deepcopy(first(primals))
7777
args_copy = deepcopy(Base.tail(primals))

lib/EnzymeTestUtils/src/test_reverse.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ end
5858
5959
Here we test a rule for a function of an array in batch reverse-mode:
6060
61-
62-
6361
```julia
6462
x = randn(3)
6563
for Tret in (Const, Active), Tx in (Const, BatchDuplicated)
@@ -79,14 +77,14 @@ function test_reverse(
7977
testset_name=nothing,
8078
runtime_activity::Bool=false,
8179
output_tangent=nothing)
82-
call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...)
80+
call_with_captured_kwargs = CallWithKWargs(fkwargs)
8381
if testset_name === nothing
8482
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
8583
end
8684
@testset "$testset_name" begin
8785
# format arguments for autodiff and FiniteDifferences
8886
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
89-
primals = map(x -> x.val, activities)
87+
primals = map(get_primal, activities)
9088
# call primal, avoid mutating original arguments
9189
fcopy = deepcopy(first(primals))
9290
args_copy = deepcopy(Base.tail(primals))
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
struct CallWithKWargs{KW}
3+
kwargs::KW
4+
end
5+
6+
function (c::CallWithKWargs)(f, xs...)
7+
f(xs...; c.kwargs...)
8+
end
9+
10+
struct CallWithCopyKWargs{KW}
11+
kwargs::KW
12+
end
13+
14+
function (c::CallWithCopyKWargs)(f, xs...)
15+
deepcopy(f)(deepcopy(xs)...; deepcopy(c.kwargs)...)
16+
end
17+
18+
@inline function get_primal(x::Annotation)
19+
x.val
20+
end

src/internal_rules.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ function EnzymeRules.augmented_primal(
793793
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
794794
end
795795

796-
# This is required to handle arugments that mix real and complex numbers
796+
# This is required to handle arguments that mix real and complex numbers
797797
_project(::Type{<:Real}, x) = x
798798
_project(::Type{<:Real}, x::Complex) = real(x)
799799
_project(::Type{<:Complex}, x) = x
@@ -922,11 +922,27 @@ function EnzymeRules.reverse(
922922
return (nothing, nothing, nothing, dα, dβ)
923923
end
924924

925+
function cofactor(A)
926+
cofA = similar(A)
927+
minorAij = similar(A, size(A, 1) - 1, size(A, 2) - 1)
928+
for i in 1:size(A, 1), j in 1:size(A, 2)
929+
fill!(minorAij, zero(eltype(A)))
925930

931+
# build minor matrix
932+
for k in 1:size(A, 1), l in 1:size(A, 2)
933+
if !(k == i || l == j)
934+
ki = k < i ? k : k - 1
935+
li = l < j ? l : l - 1
936+
@inbounds minorAij[ki, li] = A[k, l]
937+
end
938+
end
939+
@inbounds cofA[i, j] = (-1)^(i - 1 + j - 1) * det(minorAij)
940+
end
941+
return cofA
942+
end
926943

927-
928-
929-
944+
# partial derivative of the determinant is the matrix of cofactors
945+
EnzymeRules.@easy_rule(LinearAlgebra.det(A::AbstractMatrix), (cofactor(A),))
930946

931947
function EnzymeRules.forward(
932948
config::EnzymeRules.FwdConfig,

src/rules/customrules.jl

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ import LinearAlgebra
4848
:(LinearAlgebra.dot(partial,dx))
4949
elseif dx <: AbstractFloat || dx <: AbstractArray{<:AbstractFloat}
5050
:(LinearAlgebra.dot(dx, partial))
51-
else
51+
elseif partial <: AbstractVector
5252
:(LinearAlgebra.dot(adjoint(partial),dx))
53+
else
54+
:(LinearAlgebra.dot(conj(partial),dx))
5355
end
5456
return quote
5557
Base.@_inline_meta
@@ -106,10 +108,52 @@ import LinearAlgebra
106108
end
107109
end
108110

109-
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial, dx)
111+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Real, dx)
112+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
113+
end
114+
115+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Complex, dx)
116+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
117+
end
118+
119+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::Number)
120+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
121+
end
122+
123+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::Number)
124+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
125+
end
126+
127+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real, N}, dx::AbstractArray{<:Any, N}) where N
128+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
129+
end
130+
131+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex, N}, dx::AbstractArray{<:Any, N}) where N
132+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
133+
end
134+
135+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractVector{<:Complex}, dx::AbstractVector{<:Any})
110136
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
111137
end
112138

139+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Real}, dx::AbstractVector)
140+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, transpose(partial), dx)
141+
end
142+
143+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Complex}, dx::AbstractVector)
144+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
145+
end
146+
147+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::AbstractArray)
148+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...)), dx)
149+
end
150+
151+
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::AbstractArray)
152+
pd = Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...))
153+
Base.conj!(pd)
154+
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, pd, dx)
155+
end
156+
113157
function enzyme_custom_setup_args(
114158
@nospecialize(B::Union{Nothing, LLVM.IRBuilder}),
115159
orig::LLVM.CallInst,

test/rules/internal_rules.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Enzyme
22
using EnzymeTestUtils
3-
import Random
3+
import Random, LinearAlgebra
44
using Test
55

66
struct TPair
@@ -207,3 +207,31 @@ end
207207
end
208208
end
209209
end
210+
211+
@testset "(matrix) det" begin
212+
@testset "forward" begin
213+
@testset for RT in (Const,DuplicatedNoNeed,Duplicated,),
214+
Tx in (Const,Duplicated,)
215+
xr = [4.0 3.0; 2.0 1.0]
216+
test_forward(LinearAlgebra.det, RT, (xr, Tx))
217+
218+
xc = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
219+
test_forward(LinearAlgebra.det, RT, (xc, Tx))
220+
end
221+
end
222+
@testset "reverse" begin
223+
@testset for RT in (Const, Active,), Tx in (Const, Duplicated,)
224+
225+
# TODO see https://github.com/EnzymeAD/Enzyme/issues/2537
226+
if RT <: Const
227+
continue
228+
end
229+
230+
x = [4.0 3.0; 2.0 1.0]
231+
test_reverse(LinearAlgebra.det, RT, (x, Tx))
232+
233+
x = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
234+
test_reverse(LinearAlgebra.det, RT, (x, Tx))
235+
end
236+
end
237+
end

0 commit comments

Comments
 (0)