Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/Ariadne.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,34 @@ using Krylov
using LinearAlgebra, SparseArrays
using Enzyme

"""
create_shadow(x)

Allocate a tangent or adjoint value for `x`. The value must be initialized to zero, of the same type as `x`,
and have the same structure (e.g., shape for arrays). By default `Enzyme.make_zero` is used.
"""
function create_shadow(x)
return Enzyme.make_zero(x)
end

"""
zero_shadow!(x)

Set the tangent or adjoint value `x` to zero. By default `Enzyme.remake_zero!` is used.
"""
function zero_shadow!(x)
Enzyme.remake_zero!(x)
return nothing
end

##
# JacobianOperator
##
import LinearAlgebra: mul!

function init_cache(x)
if !Enzyme.Compiler.guaranteed_const(typeof(x))
Enzyme.make_zero(x)
create_shadow(x)
else
return nothing
end
Expand All @@ -23,7 +43,7 @@ function maybe_duplicated(x::T, x′::Union{Nothing, T}) where {T}
if x′ === nothing
return Const(x)
else
Enzyme.remake_zero!(x′)
zero_shadow!(x′)
return Duplicated(x, x′)
end
end
Expand Down Expand Up @@ -104,10 +124,9 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
return nothing
end


function init_cache(x, ::Val{N}) where {N}
if !Enzyme.Compiler.guaranteed_const(typeof(x))
return ntuple(_ -> Enzyme.make_zero(x), Val(N))
return ntuple(_ -> create_shadow(x), Val(N))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why create an anonymous function if you can use

Suggested change
return ntuple(_ -> create_shadow(x), Val(N))
return ntuple(create_shadow, Val(N))

directly?

else
return nothing
end
Expand All @@ -117,7 +136,7 @@ function maybe_duplicated(x::T, x′::Union{Nothing, NTuple{N, T}}, ::Val{N}) wh
if x′ === nothing
return Const(x)
else
Enzyme.remake_zero!(x′)
zero_shadow!(x′)
return BatchDuplicated(x, x′)
end
end
Expand Down
Loading