Skip to content

Commit e099b56

Browse files
vchuravylcw
andauthored
Implement support for finalizers (#2736)
Co-authored-by: Lucas Wilcox <lucas@swirlee.com>
1 parent 9950e99 commit e099b56

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

docs/src/faq.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,4 +744,38 @@ autodiff(Forward, f_val, Duplicated(Val(1.0), Val(1.0)))
744744
ERROR: Type of ghost or constant type Duplicated{Val{1.0}} is marked as differentiable.
745745
```
746746

747+
## Finalizers
748+
749+
Julia supports attaching finalizers to objects (see the listing below for an example)
750+
751+
```julia
752+
mutable struct Obj
753+
x::Float64
754+
function Obj(x)
755+
o = new(x)
756+
finalizer(o) do o
757+
# do someting with o
758+
end
759+
return o
760+
end
761+
end
762+
```
763+
764+
When Enzyme encounters a code like:
765+
766+
```julia
767+
function f(x)
768+
o = Obj(x)
769+
# computations over o
770+
return o.x
771+
end
772+
773+
autodiff(Forward, f, Duplicated(1.0, 1.0))
774+
```
775+
776+
Enzyme has to allocate a shadow object for `o` and in the process encounters the finalizer being attached to the primal object.
777+
Now the question is what should Enzyme do with the finalizer for the shadow objects? One option would be to simply ignore it,
778+
but finalizers are often used for resource management (like manually allocating memory) and thus we would leak resources that are attached
779+
to the shadow object. Instead, we define finalizers to be inactive (contain no instructions that are relevant with respect to AD),
780+
yet we must attach them to the shadow object to release resources attached to them.
747781

src/internal_rules.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,3 +1797,44 @@ function EnzymeRules.reverse(
17971797
dxs = map(x -> _hypotreverse(x, w, dret, n), xs)
17981798
return (dx, dy, dz, dxs...)
17991799
end
1800+
1801+
function EnzymeRules.forward(config, ::Const{typeof(Base.finalizer)}, _, f::Const, o)
1802+
f = f.val
1803+
Base.finalizer(f, o.val)
1804+
if EnzymeRules.width(config) == 1
1805+
Base.finalizer(f, o.dval)
1806+
else
1807+
foreach(o.dval) do dv
1808+
Base.finalizer(f, dv)
1809+
end
1810+
end
1811+
1812+
if EnzymeRules.needs_primal(config)
1813+
return o
1814+
else
1815+
return nothing
1816+
end
1817+
end
1818+
1819+
function EnzymeRules.augmented_primal(config, ::Const{typeof(Base.finalizer)}, _, f::Const, o)
1820+
@assert !(o isa Active)
1821+
f = f.val
1822+
Base.finalizer(f, o.val)
1823+
if EnzymeRules.width(config) == 1
1824+
Base.finalizer(f, o.dval)
1825+
else
1826+
foreach(o.dval) do dv
1827+
Base.finalizer(f, dv)
1828+
end
1829+
end
1830+
1831+
primal = EnzymeRules.needs_primal(config) ? o.val : nothing
1832+
shadow = EnzymeRules.needs_shadow(config) ? o.dval : nothing
1833+
1834+
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
1835+
end
1836+
1837+
function EnzymeRules.reverse(config, ::Const{typeof(Base.finalizer)}, dret, tape, f::Const, o)
1838+
# No-op
1839+
return (nothing, nothing)
1840+
end

test/finalizers.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using Enzyme
2+
using Test
3+
4+
const FREE_LIST = Vector{Any}()
5+
6+
mutable struct Container
7+
value::Float64
8+
function Container(v::Float64)
9+
c = new(v)
10+
finalizer(c) do c
11+
# Necromance object
12+
push!(FREE_LIST, c)
13+
end
14+
return c
15+
end
16+
end
17+
18+
@noinline function compute(c)
19+
return c.value^2
20+
end
21+
22+
function compute(x::Float64)
23+
c = Container(x)
24+
return compute(c)
25+
end
26+
27+
@testset "primal" begin
28+
x = compute(1.0)
29+
@test x == 1.0
30+
GC.gc()
31+
@test length(FREE_LIST) == 1
32+
empty!(FREE_LIST)
33+
end
34+
35+
@testset "forward" begin
36+
dx, x = autodiff(ForwardWithPrimal, compute, Duplicated(1.0, 2.0))
37+
@test x == 1.0
38+
@test dx == 4.0
39+
GC.gc()
40+
@test length(FREE_LIST) == 2
41+
empty!(FREE_LIST)
42+
43+
dx, = autodiff(Forward, compute, Duplicated(1.0, 2.0))
44+
@test dx == 4.0
45+
GC.gc()
46+
@test length(FREE_LIST) == 2
47+
empty!(FREE_LIST)
48+
end
49+
50+
@testset "batched forward" begin
51+
dx, x = autodiff(ForwardWithPrimal, compute, BatchDuplicated(1.0, (1.0, 2.0)))
52+
@test x == 1.0
53+
@test dx[1] == 2.0
54+
@test dx[2] == 4.0
55+
GC.gc()
56+
@test length(FREE_LIST) == 3
57+
empty!(FREE_LIST)
58+
59+
dx, = autodiff(Forward, compute, BatchDuplicated(1.0, (1.0, 2.0)))
60+
@test dx[1] == 2.0
61+
@test dx[2] == 4.0
62+
GC.gc()
63+
@test length(FREE_LIST) == 3
64+
empty!(FREE_LIST)
65+
end
66+
67+
@testset "reverse" begin
68+
((dx,), x) = autodiff(ReverseWithPrimal, compute, Active(1.0))
69+
@test x == 1.0
70+
@test dx == 2.0
71+
GC.gc()
72+
@test length(FREE_LIST) == 2
73+
empty!(FREE_LIST)
74+
75+
((dx,),) = autodiff(Reverse, compute, Active(1.0))
76+
@test dx == 2.0
77+
GC.gc()
78+
@test length(FREE_LIST) == 2
79+
empty!(FREE_LIST)
80+
end

0 commit comments

Comments
 (0)