Skip to content

Commit 29c4dce

Browse files
committed
Add easy_rule for matrix det
1 parent 107b327 commit 29c4dce

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

src/internal_rules.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ function EnzymeRules.augmented_primal(
793793
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
794794
end
795795

796-
# This is required to handle arugments that mix real and complex numbers
796+
# This is required to handle arguments that mix real and complex numbers
797797
_project(::Type{<:Real}, x) = x
798798
_project(::Type{<:Real}, x::Complex) = real(x)
799799
_project(::Type{<:Complex}, x) = x
@@ -922,11 +922,26 @@ function EnzymeRules.reverse(
922922
return (nothing, nothing, nothing, dα, dβ)
923923
end
924924

925-
926-
927-
928-
929-
925+
function cofactor(A)
926+
cofA = zeros(eltype(A), size(A))
927+
minorAij = zeros(eltype(A), size(A, 1) - 1, size(A, 2) - 1)
928+
for i in 1:size(A, 1), j in 1:size(A, 2)
929+
# build minor matrix
930+
for k in 1:size(A, 1), l in 1:size(A, 2)
931+
if !(k == i || l == j)
932+
ki = k < i ? k : k - 1
933+
li = l < j ? l : l - 1
934+
@inbounds minorAij[ki, li] = A[k, l]
935+
end
936+
end
937+
@inbounds cofA[i, j] = (-1)^(i - 1 + j - 1) * det(minorAij)
938+
minorAij .= zero(eltype(A))
939+
end
940+
return cofA
941+
end
942+
# partial derivative of the determinant is the matrix of
943+
# cofactors
944+
EnzymeRules.@easy_rule(LinearAlgebra.det(A), (cofactor(A),))
930945

931946
function EnzymeRules.forward(
932947
config::EnzymeRules.FwdConfig,

test/rules/internal_rules.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Enzyme
22
using EnzymeTestUtils
3-
import Random
3+
import Random, LinearAlgebra
44
using Test
55

66
struct TPair
@@ -207,3 +207,25 @@ end
207207
end
208208
end
209209
end
210+
211+
@testset "(matrix) det" begin
212+
@testset "forward" begin
213+
@testset for RT in (Const,DuplicatedNoNeed,Duplicated,),
214+
Tx in (Const,Duplicated,)
215+
xr = [4.0 3.0; 2.0 1.0]
216+
test_forward(LinearAlgebra.det, RT, (xr, Tx))
217+
218+
xc = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
219+
test_forward(LinearAlgebra.det, RT, (xc, Tx))
220+
end
221+
end
222+
@testset "reverse" begin
223+
@testset for RT in (Const, Active,), Tx in (Const, Duplicated,),
224+
x = [4.0 3.0; 2.0 1.0]
225+
test_reverse(LinearAlgebra.det, RT, (x, Tx))
226+
227+
x = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
228+
test_reverse(LinearAlgebra.det, RT, (x, Tx))
229+
end
230+
end
231+
end

0 commit comments

Comments
 (0)