Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.2.6"
version = "0.2.7"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
1 change: 1 addition & 0 deletions lib/EnzymeTestUtils/src/EnzymeTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include("test_approx.jl")
include("compatible_activities.jl")
include("finite_difference_calls.jl")
include("generate_tangent.jl")
include("test_utils.jl")
include("test_forward.jl")
include("test_reverse.jl")

Expand Down
6 changes: 3 additions & 3 deletions lib/EnzymeTestUtils/src/test_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ function test_forward(
testset_name=nothing,
runtime_activity::Bool=false
)
call_with_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)
call_with_kwargs(f, xs...) = f(xs...; fkwargs...)
call_with_copy = CallWithCopyKWargs(fkwargs)
call_with_kwargs = CallWithKWargs(fkwargs)
if testset_name === nothing
testset_name = "test_forward: $f with return activity $ret_activity on $(_string_activity(args))"
end
@testset "$testset_name" begin
# format arguments for autodiff and FiniteDifferences
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
primals = map(x -> x.val, activities)
primals = map(get_primal, activities)
# call primal, avoid mutating original arguments
fcopy = deepcopy(first(primals))
args_copy = deepcopy(Base.tail(primals))
Expand Down
6 changes: 2 additions & 4 deletions lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ end

Here we test a rule for a function of an array in batch reverse-mode:



```julia
x = randn(3)
for Tret in (Const, Active), Tx in (Const, BatchDuplicated)
Expand All @@ -79,14 +77,14 @@ function test_reverse(
testset_name=nothing,
runtime_activity::Bool=false,
output_tangent=nothing)
call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...)
call_with_captured_kwargs = CallWithKWargs(fkwargs)
if testset_name === nothing
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
end
@testset "$testset_name" begin
# format arguments for autodiff and FiniteDifferences
activities = map(Base.Fix1(auto_activity, rng), (f, args...))
primals = map(x -> x.val, activities)
primals = map(get_primal, activities)
# call primal, avoid mutating original arguments
fcopy = deepcopy(first(primals))
args_copy = deepcopy(Base.tail(primals))
Expand Down
20 changes: 20 additions & 0 deletions lib/EnzymeTestUtils/src/test_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

struct CallWithKWargs{KW}
kwargs::KW
end

function (c::CallWithKWargs)(f, xs...)
f(xs...; c.kwargs...)
end

struct CallWithCopyKWargs{KW}
kwargs::KW
end

function (c::CallWithCopyKWargs)(f, xs...)
deepcopy(f)(deepcopy(xs)...; deepcopy(c.kwargs)...)
end

@inline function get_primal(x::Annotation)
x.val
end
24 changes: 20 additions & 4 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ function EnzymeRules.augmented_primal(
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

# This is required to handle arugments that mix real and complex numbers
# This is required to handle arguments that mix real and complex numbers
_project(::Type{<:Real}, x) = x
_project(::Type{<:Real}, x::Complex) = real(x)
_project(::Type{<:Complex}, x) = x
Expand Down Expand Up @@ -922,11 +922,27 @@ function EnzymeRules.reverse(
return (nothing, nothing, nothing, dα, dβ)
end

function cofactor(A)
cofA = similar(A)
minorAij = similar(A, size(A, 1) - 1, size(A, 2) - 1)
for i in 1:size(A, 1), j in 1:size(A, 2)
fill!(minorAij, zero(eltype(A)))

# build minor matrix
for k in 1:size(A, 1), l in 1:size(A, 2)
if !(k == i || l == j)
ki = k < i ? k : k - 1
li = l < j ? l : l - 1
@inbounds minorAij[ki, li] = A[k, l]
end
end
@inbounds cofA[i, j] = (-1)^(i - 1 + j - 1) * det(minorAij)
end
return cofA
end




# partial derivative of the determinant is the matrix of cofactors
EnzymeRules.@easy_rule(LinearAlgebra.det(A::AbstractMatrix), (cofactor(A),))

function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
Expand Down
48 changes: 46 additions & 2 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ import LinearAlgebra
:(LinearAlgebra.dot(partial,dx))
elseif dx <: AbstractFloat || dx <: AbstractArray{<:AbstractFloat}
:(LinearAlgebra.dot(dx, partial))
else
elseif partial <: AbstractVector
:(LinearAlgebra.dot(adjoint(partial),dx))
else
:(LinearAlgebra.dot(conj(partial),dx))
end
return quote
Base.@_inline_meta
Expand Down Expand Up @@ -106,10 +108,52 @@ import LinearAlgebra
end
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial, dx)
@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Real, dx)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::Complex, dx)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::Number)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::Number)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real, N}, dx::AbstractArray{<:Any, N}) where N
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, partial, dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex, N}, dx::AbstractArray{<:Any, N}) where N
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, conj(partial), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractVector{<:Complex}, dx::AbstractVector{<:Any})
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Real}, dx::AbstractVector)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, transpose(partial), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractMatrix{<:Complex}, dx::AbstractVector)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, adjoint(partial), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Real}, dx::AbstractArray)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...)), dx)
end

@inline function EnzymeCore.EnzymeRules.multiply_rev_into(prev, partial::AbstractArray{<:Complex}, dx::AbstractArray)
pd = Base.permutedims(partial, (((ndims(dx)+1):ndims(partial))..., Base.OneTo(ndims(dx))...))
Base.conj!(pd)
EnzymeCore.EnzymeRules.multiply_fwd_into(prev, pd, dx)
end

function enzyme_custom_setup_args(
@nospecialize(B::Union{Nothing, LLVM.IRBuilder}),
orig::LLVM.CallInst,
Expand Down
30 changes: 29 additions & 1 deletion test/rules/internal_rules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Enzyme
using EnzymeTestUtils
import Random
import Random, LinearAlgebra
using Test

struct TPair
Expand Down Expand Up @@ -207,3 +207,31 @@ end
end
end
end

@testset "(matrix) det" begin
@testset "forward" begin
@testset for RT in (Const,DuplicatedNoNeed,Duplicated,),
Tx in (Const,Duplicated,)
xr = [4.0 3.0; 2.0 1.0]
test_forward(LinearAlgebra.det, RT, (xr, Tx))

xc = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
test_forward(LinearAlgebra.det, RT, (xc, Tx))
end
end
@testset "reverse" begin
@testset for RT in (Const, Active,), Tx in (Const, Duplicated,)

# TODO see https://github.com/EnzymeAD/Enzyme/issues/2537
if RT <: Const
continue
end

x = [4.0 3.0; 2.0 1.0]
test_reverse(LinearAlgebra.det, RT, (x, Tx))

x = [4.0+0.0im 3.0; 2.0-0.0im 1.0]
test_reverse(LinearAlgebra.det, RT, (x, Tx))
end
end
end
Loading