Skip to content

Commit aa35128

Browse files
committed
add create_shadow and zero_shadow
1 parent 3eb9d27 commit aa35128

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

src/Ariadne.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,34 @@ using Krylov
66
using LinearAlgebra, SparseArrays
77
using Enzyme
88

9+
"""
10+
create_shadow(x)
11+
12+
Allocate a tangent or adjoint value for `x`. The value must be initialized to zero, of the same type as `x`,
13+
and have the same structure (e.g., shape for arrays). By default `Enzyme.make_zero` is used.
14+
"""
15+
function create_shadow(x)
16+
return Enzyme.make_zero(x)
17+
end
18+
19+
"""
20+
zero_shadow!(x)
21+
22+
Set the tangent or adjoint value `x` to zero. By default `Enzyme.remake_zero!` is used.
23+
"""
24+
function zero_shadow!(x)
25+
Enzyme.remake_zero!(x)
26+
return nothing
27+
end
28+
929
##
1030
# JacobianOperator
1131
##
1232
import LinearAlgebra: mul!
1333

1434
function init_cache(x)
1535
if !Enzyme.Compiler.guaranteed_const(typeof(x))
16-
Enzyme.make_zero(x)
36+
create_shadow(x)
1737
else
1838
return nothing
1939
end
@@ -23,7 +43,7 @@ function maybe_duplicated(x::T, x′::Union{Nothing, T}) where {T}
2343
if x′ === nothing
2444
return Const(x)
2545
else
26-
Enzyme.remake_zero!(x′)
46+
zero_shadow!(x′)
2747
return Duplicated(x, x′)
2848
end
2949
end
@@ -104,10 +124,9 @@ function mul!(out, J′::Union{Adjoint{<:Any, <:JacobianOperator}, Transpose{<:A
104124
return nothing
105125
end
106126

107-
108127
function init_cache(x, ::Val{N}) where {N}
109128
if !Enzyme.Compiler.guaranteed_const(typeof(x))
110-
return ntuple(_ -> Enzyme.make_zero(x), Val(N))
129+
return ntuple(_ -> create_shadow(x), Val(N))
111130
else
112131
return nothing
113132
end
@@ -117,7 +136,7 @@ function maybe_duplicated(x::T, x′::Union{Nothing, NTuple{N, T}}, ::Val{N}) wh
117136
if x′ === nothing
118137
return Const(x)
119138
else
120-
Enzyme.remake_zero!(x′)
139+
zero_shadow!(x′)
121140
return BatchDuplicated(x, x′)
122141
end
123142
end

0 commit comments

Comments
 (0)