Skip to content

Commit 136e165

Browse files
authored
Deepcopy of struct (#2510)
* Deepcopy of struct * fix * Update runtests.jl
1 parent c7636ff commit 136e165

File tree

2 files changed

+58
-29
lines changed

2 files changed

+58
-29
lines changed

src/internal_rules.jl

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -160,36 +160,40 @@ end
160160
@inline function deepcopy_rtact(
161161
copied::RT,
162162
primal::RT,
163-
seen::IdDict,
164-
shadow::RT,
165-
) where {RT<:Union{Integer,Char}}
166-
return Base.deepcopy_internal(shadow, seen)
167-
end
168-
@inline function deepcopy_rtact(
169-
copied::RT,
170-
primal::RT,
171-
seen::IdDict,
172-
shadow::RT,
173-
) where {RT<:AbstractFloat}
174-
return Base.deepcopy_internal(shadow, seen)
175-
end
176-
@inline function deepcopy_rtact(
177-
copied::RT,
178-
primal::RT,
179-
seen::IdDict,
163+
seen::Union{IdDict,Nothing},
180164
shadow::RT,
181-
) where {RT<:Array}
182-
if !haskey(seen, shadow)
165+
)::RT where RT
166+
rt = Enzyme.Compiler.active_reg_inner(RT, (), nothing)
167+
if rt == Enzyme.Compiler.ActiveState || rt == Enzyme.Compiler.AnyState
168+
if seen === nothing
169+
return Base.deepcopy(shadow)
170+
else
171+
return Base.deepcopy_internal(shadow, seen)
172+
end
173+
else
174+
if seen !== nothing && haskey(seen, shadow)
175+
return seen[shadow]
176+
end
183177
if primal === shadow
184-
return seen[shadow] = copied
178+
if seen !== nothing
179+
seen[shadow] = copied
180+
end
181+
return copied
185182
end
186-
newa = RT(undef, size(shadow))
187-
seen[shadow] = newa
188-
for i in eachindex(shadow)
189-
@inbounds newa[i] = deepcopy_rtact(copied[i], primal[i], seen, shadow[i])
183+
184+
if RT <: Array
185+
newa = similar(primal, size(shadow))
186+
if seen === nothing
187+
seen = IdDict()
188+
end
189+
seen[shadow] = newa
190+
for i in eachindex(shadow)
191+
@inbounds newa[i] = deepcopy_rtact(copied[i], primal[i], seen, shadow[i])
192+
end
193+
return newa
190194
end
195+
throw(AssertionError("Unimplemented deepcopy with runtime activity for type $RT"))
191196
end
192-
return seen[shadow]
193197
end
194198

195199
function EnzymeRules.forward(
@@ -199,7 +203,7 @@ function EnzymeRules.forward(
199203
x::Duplicated,
200204
)
201205
primal = func.val(x.val)
202-
return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval))
206+
return Duplicated(primal, deepcopy_rtact(primal, x.val, nothing, x.dval))
203207
end
204208

205209
function EnzymeRules.forward(
@@ -210,7 +214,7 @@ function EnzymeRules.forward(
210214
) where {T,N}
211215
primal = func.val(x.val)
212216
return BatchDuplicated(primal, ntuple(Val(N)) do i
213-
deepcopy_rtact(primal, x.val, IdDict(), x.dval[i])
217+
deepcopy_rtact(primal, x.val, nothing, x.dval[i])
214218
end)
215219
end
216220

@@ -226,15 +230,14 @@ function EnzymeRules.augmented_primal(
226230
nothing
227231
end
228232

229-
@assert !(typeof(x) <: Active)
230-
231233
source = if EnzymeRules.needs_primal(config)
232234
primal
233235
else
234236
x.val
235237
end
236238

237239
shadow = if EnzymeRules.needs_shadow(config)
240+
@assert !(x isa Active)
238241
if EnzymeRules.width(config) == 1
239242
Enzyme.make_zero(
240243
source,
@@ -304,6 +307,8 @@ function EnzymeRules.reverse(
304307
shadow,
305308
x::Annotation{Ty},
306309
) where {RT,Ty}
310+
@assert !(x isa Active)
311+
307312
if EnzymeRules.needs_shadow(config)
308313
if EnzymeRules.width(config) == 1
309314
accumulate_into(x.dval, IdDict(), shadow)
@@ -317,6 +322,17 @@ function EnzymeRules.reverse(
317322
return (nothing,)
318323
end
319324

325+
function EnzymeRules.reverse(
326+
config::EnzymeRules.RevConfig,
327+
func::Const{typeof(Base.deepcopy)},
328+
dret::Active,
329+
shadow,
330+
x::Annotation,
331+
)
332+
return (dret.val,)
333+
end
334+
335+
320336
@inline function pmap_fwd(
321337
idx,
322338
tapes::Vector,

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,19 @@ make3() = (1.0, 2.0, 3.0)
443443

444444
end
445445

446+
function named_deepcopy(x, nt)
447+
nt2 = deepcopy(nt)
448+
return nt2.a + x[1]
449+
end
450+
451+
@testset "Deepcopy" begin
452+
nt = (a = 0.0,)
453+
x = [0.5]
454+
455+
@test Enzyme.gradient(Forward, named_deepcopy, x, Const(nt))[1] [1.0]
456+
@test Enzyme.gradient(Reverse, named_deepcopy, x, Const(nt))[1] [1.0]
457+
end
458+
446459
@testset "Deferred and deferred thunk" begin
447460
function dot(A)
448461
return A[1] * A[1] + A[2] * A[2]

0 commit comments

Comments
 (0)