Skip to content

Add easy_rule for matrix det#2725

Merged
wsmoses merged 5 commits intoEnzymeAD:mainfrom
kshyatt:ksh/det
Nov 9, 2025
Merged

Add easy_rule for matrix det#2725
wsmoses merged 5 commits intoEnzymeAD:mainfrom
kshyatt:ksh/det

Conversation

@kshyatt
Copy link
Collaborator

@kshyatt kshyatt commented Oct 31, 2025

Something very odd going on with some of this.

  • test_forward fails for the Matrix{ComplexF64} even though it's (secretly) all real valued
  • test_reverse seems to be running the rule on each element of the matrix argument individually @sethaxen any ideas why?

end
end
@inbounds cofA[i, j] = (-1)^(i - 1 + j - 1) * det(minorAij)
minorAij .= zero(eltype(A))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use fill!(minorAij, zero) instead?

@codecov
Copy link

codecov bot commented Oct 31, 2025

Codecov Report

❌ Patch coverage is 30.18868% with 37 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.87%. Comparing base (107b327) to head (f6b5a16).
⚠️ Report is 27 commits behind head on main.

Files with missing lines Patch % Lines
src/rules/customrules.jl 8.00% 23 Missing ⚠️
src/internal_rules.jl 6.66% 14 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2725      +/-   ##
==========================================
- Coverage   72.61%   69.87%   -2.75%     
==========================================
  Files          58       58              
  Lines       18746    19369     +623     
==========================================
- Hits        13613    13534      -79     
- Misses       5133     5835     +702     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@wsmoses wsmoses marked this pull request as ready for review November 8, 2025 17:24
@github-actions
Copy link
Contributor

github-actions bot commented Nov 8, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/EnzymeTestUtils/src/test_utils.jl b/lib/EnzymeTestUtils/src/test_utils.jl
index d111f537..820306bf 100644
--- a/lib/EnzymeTestUtils/src/test_utils.jl
+++ b/lib/EnzymeTestUtils/src/test_utils.jl
@@ -4,7 +4,7 @@ struct CallWithKWargs{KW}
 end
 
 function (c::CallWithKWargs)(f, xs...)
-    f(xs...; c.kwargs...)
+    return f(xs...; c.kwargs...)
 end
 
 struct CallWithCopyKWargs{KW}
@@ -12,9 +12,9 @@ struct CallWithCopyKWargs{KW}
 end
 
 function (c::CallWithCopyKWargs)(f, xs...)
-    deepcopy(f)(deepcopy(xs)...; deepcopy(c.kwargs)...)
+    return deepcopy(f)(deepcopy(xs)...; deepcopy(c.kwargs)...)
 end
 
 @inline function get_primal(x::Annotation)
-    x.val
-end
\ No newline at end of file
+    return x.val
+end
diff --git a/src/internal_rules.jl b/src/internal_rules.jl
index b359c670..a02d0e79 100644
--- a/src/internal_rules.jl
+++ b/src/internal_rules.jl
@@ -923,7 +923,7 @@ function EnzymeRules.reverse(
 end
 
 function cofactor(A)
-    cofA     = similar(A)
+    cofA = similar(A)
     minorAij = similar(A, size(A, 1) - 1, size(A, 2) - 1)
     for i in 1:size(A, 1), j in 1:size(A, 2)
         fill!(minorAij, zero(eltype(A)))
diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl
index 415affb2..e83d73af 100644
--- a/src/rules/customrules.jl
+++ b/src/rules/customrules.jl
@@ -51,7 +51,7 @@ import LinearAlgebra
         elseif partial <: AbstractVector
             :(LinearAlgebra.dot(adjoint(partial),dx))
         else
-            :(LinearAlgebra.dot(conj(partial),dx))
+            :(LinearAlgebra.dot(conj(partial), dx))
         end
         return quote
             Base.@_inline_meta
@@ -109,27 +109,27 @@ import LinearAlgebra
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Real, dx)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Complex, dx)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::Number)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::Number)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
 end
 
-@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real, N}, dx::AbstractArray{<:Any, N}) where N
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
+@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real, N}, dx::AbstractArray{<:Any, N}) where {N}
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
 end
 
-@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex, N}, dx::AbstractArray{<:Any, N}) where N
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
+@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex, N}, dx::AbstractArray{<:Any, N}) where {N}
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractVector{<:Complex}, dx::AbstractVector{<:Any})
@@ -137,21 +137,21 @@ end
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Real}, dx::AbstractVector)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, transpose(partial), dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, transpose(partial), dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Complex}, dx::AbstractVector)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::AbstractArray)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...)), dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, Base.permutedims(partial, (((ndims(dx) + 1):ndims(partial))..., Base.OneTo(ndims(dx))...)), dx)
 end
 
 @inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::AbstractArray)
-    pd = Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...))
+    pd = Base.permutedims(partial, (((ndims(dx) + 1):ndims(partial))..., Base.OneTo(ndims(dx))...))
     Base.conj!(pd)
-    EnzymeCore.EnzymeRules.multiply_fwd_into(prev, pd, dx)
+    return EnzymeCore.EnzymeRules.multiply_fwd_into(prev, pd, dx)
 end
 
 function enzyme_custom_setup_args(
diff --git a/test/rules/internal_rules.jl b/test/rules/internal_rules.jl
index d9a60656..8ae1b8c2 100644
--- a/test/rules/internal_rules.jl
+++ b/test/rules/internal_rules.jl
@@ -210,17 +210,17 @@ end
 
 @testset "(matrix) det" begin
     @testset "forward" begin
-        @testset for RT in (Const,DuplicatedNoNeed,Duplicated,),
-                     Tx in (Const,Duplicated,)
+        @testset for RT in (Const, DuplicatedNoNeed, Duplicated),
+                Tx in (Const, Duplicated)
             xr = [4.0 3.0; 2.0 1.0]
             test_forward(LinearAlgebra.det, RT, (xr, Tx))
 
-            xc = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
+            xc = [4.0 + 0.0im 3.0; 2.0 - 0.0im 1.0]
             test_forward(LinearAlgebra.det, RT, (xc, Tx))
         end
     end
     @testset "reverse" begin
-        @testset for RT in (Const, Active,), Tx in (Const, Duplicated,)
+        @testset for RT in (Const, Active), Tx in (Const, Duplicated)
 
             # TODO see https://github.com/EnzymeAD/Enzyme/issues/2537
             if RT <: Const
@@ -230,7 +230,7 @@ end
             x = [4.0 3.0; 2.0 1.0]
             test_reverse(LinearAlgebra.det, RT, (x, Tx))
 
-            x = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
+            x = [4.0 + 0.0im 3.0; 2.0 - 0.0im 1.0]
             test_reverse(LinearAlgebra.det, RT, (x, Tx))
         end
     end

@wsmoses
Copy link
Member

wsmoses commented Nov 8, 2025

  • test_forward fails for the Matrix{ComplexF64} even though it's (secretly) all real valued

So in this case it was actually testing complex [e.g. the random seeds had imag components], now fixed the complex failures from internal side

  • test_reverse seems to be running the rule on each element of the matrix argument individually @sethaxen any ideas why?

so you were acidentally doing for x in [...] because of a trailing comma

@wsmoses wsmoses merged commit dbadb67 into EnzymeAD:main Nov 9, 2025
42 of 48 checks passed
@kshyatt kshyatt deleted the ksh/det branch November 9, 2025 07:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants