diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 39f84c88dd..7cc53e8891 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.2.6" +version = "0.2.7" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" diff --git a/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl b/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl index 56a050455b..7968ea6696 100644 --- a/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl +++ b/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl @@ -15,6 +15,7 @@ include("test_approx.jl") include("compatible_activities.jl") include("finite_difference_calls.jl") include("generate_tangent.jl") +include("test_utils.jl") include("test_forward.jl") include("test_reverse.jl") diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index 8830ce6784..e39006b823 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -63,15 +63,15 @@ function test_forward( testset_name=nothing, runtime_activity::Bool=false ) - call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...) - call_with_kwargs(f, xs...) = f(xs...; fkwargs...) + call_with_copy = CallWithCopyKWargs(fkwargs) + call_with_kwargs = CallWithKWargs(fkwargs) if testset_name === nothing testset_name = "test_forward: $f with return activity $ret_activity on $(_string_activity(args))" end @testset "$testset_name" begin # format arguments for autodiff and FiniteDifferences activities = map(Base.Fix1(auto_activity, rng), (f, args...)) - primals = map(x -> x.val, activities) + primals = map(get_primal, activities) # call primal, avoid mutating original arguments fcopy = deepcopy(first(primals)) args_copy = deepcopy(Base.tail(primals)) diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index cf904bec3e..590ada1b59 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -58,8 +58,6 @@ end Here we test a rule for a function of an array in batch reverse-mode: - - ```julia x = randn(3) for Tret in (Const, Active), Tx in (Const, BatchDuplicated) @@ -79,14 +77,14 @@ function test_reverse( testset_name=nothing, runtime_activity::Bool=false, output_tangent=nothing) - call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...) + call_with_captured_kwargs = CallWithKWargs(fkwargs) if testset_name === nothing testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))" end @testset "$testset_name" begin # format arguments for autodiff and FiniteDifferences activities = map(Base.Fix1(auto_activity, rng), (f, args...)) - primals = map(x -> x.val, activities) + primals = map(get_primal, activities) # call primal, avoid mutating original arguments fcopy = deepcopy(first(primals)) args_copy = deepcopy(Base.tail(primals)) diff --git a/lib/EnzymeTestUtils/src/test_utils.jl b/lib/EnzymeTestUtils/src/test_utils.jl new file mode 100644 index 0000000000..d111f53799 --- /dev/null +++ b/lib/EnzymeTestUtils/src/test_utils.jl @@ -0,0 +1,20 @@ + +struct CallWithKWargs{KW} + kwargs::KW +end + +function (c::CallWithKWargs)(f, xs...) + f(xs...; c.kwargs...) +end + +struct CallWithCopyKWargs{KW} + kwargs::KW +end + +function (c::CallWithCopyKWargs)(f, xs...) + deepcopy(f)(deepcopy(xs)...; deepcopy(c.kwargs)...) +end + +@inline function get_primal(x::Annotation) + x.val +end \ No newline at end of file diff --git a/src/internal_rules.jl b/src/internal_rules.jl index eba792b6a4..b359c67038 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -793,7 +793,7 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, cache) end -# This is required to handle arugments that mix real and complex numbers +# This is required to handle arguments that mix real and complex numbers _project(::Type{<:Real}, x) = x _project(::Type{<:Real}, x::Complex) = real(x) _project(::Type{<:Complex}, x) = x @@ -922,11 +922,27 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing, dα, dβ) end +function cofactor(A) + cofA = similar(A) + minorAij = similar(A, size(A, 1) - 1, size(A, 2) - 1) + for i in 1:size(A, 1), j in 1:size(A, 2) + fill!(minorAij, zero(eltype(A))) + # build minor matrix + for k in 1:size(A, 1), l in 1:size(A, 2) + if !(k == i || l == j) + ki = k < i ? k : k - 1 + li = l < j ? l : l - 1 + @inbounds minorAij[ki, li] = A[k, l] + end + end + @inbounds cofA[i, j] = (-1)^(i - 1 + j - 1) * det(minorAij) + end + return cofA +end - - - +# partial derivative of the determinant is the matrix of cofactors +EnzymeRules.@easy_rule(LinearAlgebra.det(A::AbstractMatrix), (cofactor(A),)) function EnzymeRules.forward( config::EnzymeRules.FwdConfig, diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 2a86748722..415affb226 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -48,8 +48,10 @@ import LinearAlgebra :(LinearAlgebra.dot(partial,dx)) elseif dx <: AbstractFloat || dx <: AbstractArray{<:AbstractFloat} :(LinearAlgebra.dot(dx, partial)) - else + elseif partial <: AbstractVector :(LinearAlgebra.dot(adjoint(partial),dx)) + else + :(LinearAlgebra.dot(conj(partial),dx)) end return quote Base.@_inline_meta @@ -106,10 +108,52 @@ import LinearAlgebra end end -@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial, dx) +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Real, dx) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Complex, dx) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::Number) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::Number) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real, N}, dx::AbstractArray{<:Any, N}) where N + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex, N}, dx::AbstractArray{<:Any, N}) where N + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractVector{<:Complex}, dx::AbstractVector{<:Any}) EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx) end +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Real}, dx::AbstractVector) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, transpose(partial), dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Complex}, dx::AbstractVector) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::AbstractArray) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...)), dx) +end + +@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::AbstractArray) + pd = Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...)) + Base.conj!(pd) + EnzymeCore.EnzymeRules.multiply_fwd_into(prev, pd, dx) +end + function enzyme_custom_setup_args( @nospecialize(B::Union{Nothing, LLVM.IRBuilder}), orig::LLVM.CallInst, diff --git a/test/rules/internal_rules.jl b/test/rules/internal_rules.jl index 99d292caf2..d9a60656fd 100644 --- a/test/rules/internal_rules.jl +++ b/test/rules/internal_rules.jl @@ -1,6 +1,6 @@ using Enzyme using EnzymeTestUtils -import Random +import Random, LinearAlgebra using Test struct TPair @@ -207,3 +207,31 @@ end end end end + +@testset "(matrix) det" begin + @testset "forward" begin + @testset for RT in (Const,DuplicatedNoNeed,Duplicated,), + Tx in (Const,Duplicated,) + xr = [4.0 3.0; 2.0 1.0] + test_forward(LinearAlgebra.det, RT, (xr, Tx)) + + xc = [4.0+0.0im 3.0; 2.0-0.0im 1.0] + test_forward(LinearAlgebra.det, RT, (xc, Tx)) + end + end + @testset "reverse" begin + @testset for RT in (Const, Active,), Tx in (Const, Duplicated,) + + # TODO see https://github.com/EnzymeAD/Enzyme/issues/2537 + if RT <: Const + continue + end + + x = [4.0 3.0; 2.0 1.0] + test_reverse(LinearAlgebra.det, RT, (x, Tx)) + + x = [4.0+0.0im 3.0; 2.0-0.0im 1.0] + test_reverse(LinearAlgebra.det, RT, (x, Tx)) + end + end +end