diff --git a/src/Ariadne.jl b/src/Ariadne.jl index 65e77d1a..b08ba87b 100644 --- a/src/Ariadne.jl +++ b/src/Ariadne.jl @@ -6,6 +6,26 @@ using Krylov using LinearAlgebra, SparseArrays using Enzyme +""" + create_shadow(x) + +Allocate a tangent or adjoint value for `x`. The value must be initialized to zero, of the same type as `x`, +and have the same structure (e.g., shape for arrays). By default `Enzyme.make_zero` is used. +""" +function create_shadow(x) + return Enzyme.make_zero(x) +end + +""" + zero_shadow!(x) + +Set the tangent or adjoint value `x` to zero. By default `Enzyme.remake_zero!` is used. +""" +function zero_shadow!(x) + Enzyme.remake_zero!(x) + return nothing +end + ## # JacobianOperator ## @@ -13,7 +33,7 @@ import LinearAlgebra: mul! function init_cache(x) if !Enzyme.Compiler.guaranteed_const(typeof(x)) - Enzyme.make_zero(x) + create_shadow(x) else return nothing end @@ -23,7 +43,7 @@ function maybe_duplicated(x::T, x′::Union{Nothing, T}) where {T} if x′ === nothing return Const(x) else - Enzyme.remake_zero!(x′) + zero_shadow!(x′) return Duplicated(x, x′) end end @@ -104,10 +124,9 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A return nothing end - function init_cache(x, ::Val{N}) where {N} if !Enzyme.Compiler.guaranteed_const(typeof(x)) - return ntuple(_ -> Enzyme.make_zero(x), Val(N)) + return ntuple(_ -> create_shadow(x), Val(N)) else return nothing end @@ -117,7 +136,7 @@ function maybe_duplicated(x::T, x′::Union{Nothing, NTuple{N, T}}, ::Val{N}) wh if x′ === nothing return Const(x) else - Enzyme.remake_zero!(x′) + zero_shadow!(x′) return BatchDuplicated(x, x′) end end