Skip to content

Commit 1b3477b

Browse files
committed
use a local cache for activity reg
1 parent 162179a commit 1b3477b

File tree

3 files changed

+42
-20
lines changed

3 files changed

+42
-20
lines changed

src/analyses/activity.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
@enum ActivityState begin
2-
AnyState = 0
3-
ActiveState = 1
4-
DupState = 2
5-
MixedState = 3
6-
end
7-
8-
@inline function Base.:|(a1::ActivityState, a2::ActivityState)
9-
ActivityState(Int(a1) | Int(a2))
10-
end
11-
121
@inline element(::Val{T}) where {T} = T
132

143
@inline ptreltype(::Type{Ptr{T}}) where {T} = T
@@ -393,6 +382,14 @@ Base.@nospecializeinfer @inline function active_reg_inner(
393382
return ty
394383
end
395384

385+
function active_reg_cached(ctx::EnzymeContext, @nospecialize(ST::Type); justActive=false, UnionSret = false, AbstractIsMixed = false)
386+
key = (ST, justActive, UnionSret, AbstractIsMixed)
387+
get!(ctx.activity_cache, key) do
388+
set = Base.IdSet{Type}()
389+
active_reg_inner(ST, set, ctx.world, justActive, UnionSret, AbstractIsMixed)
390+
end
391+
end
392+
396393
Base.@nospecializeinfer @inline function active_reg(@nospecialize(ST::Type), world::UInt; justActive=false, UnionSret = false, AbstractIsMixed = false)
397394
set = Base.IdSet{Type}()
398395
return active_reg_inner(ST, set, world, justActive, UnionSret, AbstractIsMixed)

src/compiler.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@ import LLVM: Target, TargetMachine
4242
import SparseArrays
4343
using Printf
4444

45+
@enum ActivityState begin
46+
AnyState = 0
47+
ActiveState = 1
48+
DupState = 2
49+
MixedState = 3
50+
end
51+
52+
@inline function Base.:|(a1::ActivityState, a2::ActivityState)
53+
ActivityState(Int(a1) | Int(a2))
54+
end
55+
56+
mutable struct EnzymeContext
57+
world::UInt
58+
activity_cache::Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}
59+
function EnzymeContext(world)
60+
new(world, Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}())
61+
end
62+
end
63+
4564
using Preferences
4665

4766
bitcode_replacement() = parse(Bool, @load_preference("bitcode_replacement", "true"))
@@ -3282,7 +3301,7 @@ function create_abi_wrapper(
32823301
# 3 is index of shadow
32833302
if existed[3] != 0 &&
32843303
sret_union &&
3285-
active_reg(pactualRetType, world; justActive=true, UnionSret=true) == ActiveState
3304+
active_reg_cached(interp.context, pactualRetType; justActive=true, UnionSret=true) == ActiveState
32863305
rewrite_union_returns_as_ref(enzymefn, data[3], world, width)
32873306
end
32883307
returnNum = 0
@@ -4951,7 +4970,7 @@ end
49514970
if params.err_if_func_written
49524971
FT = TT.parameters[1]
49534972
Ty = eltype(FT)
4954-
reg = active_reg(Ty, job.world)
4973+
reg = active_reg_cached(interp.context, Ty)
49554974
if reg == DupState || reg == MixedState
49564975
swiftself = has_swiftself(primalf)
49574976
todo = LLVM.Value[parameters(primalf)[1+swiftself]]
@@ -4975,7 +4994,7 @@ end
49754994
if !mayWriteToMemory(user)
49764995
slegal, foundv, byref = abs_typeof(user)
49774996
if slegal
4978-
reg2 = active_reg(foundv, job.world)
4997+
reg2 = active_reg_cached(interp.context, foundv)
49794998
if reg2 == ActiveState || reg2 == AnyState
49804999
continue
49815000
end
@@ -5003,7 +5022,7 @@ end
50035022
if operands(user)[2] == cur
50045023
slegal, foundv, byref = abs_typeof(operands(user)[1])
50055024
if slegal
5006-
reg2 = active_reg(foundv, job.world)
5025+
reg2 = active_reg_cached(interp.context, foundv)
50075026
if reg2 == AnyState
50085027
continue
50095028
end
@@ -5037,7 +5056,7 @@ end
50375056
if is_readonly(called)
50385057
slegal, foundv, byref = abs_typeof(user)
50395058
if slegal
5040-
reg2 = active_reg(foundv, job.world)
5059+
reg2 = active_reg_cached(interp.context, foundv)
50415060
if reg2 == ActiveState || reg2 == AnyState
50425061
continue
50435062
end
@@ -5055,7 +5074,7 @@ end
50555074
end
50565075
slegal, foundv, byref = abs_typeof(user)
50575076
if slegal
5058-
reg2 = active_reg(foundv, job.world)
5077+
reg2 = active_reg_cached(interp.context, foundv)
50595078
if reg2 == ActiveState || reg2 == AnyState
50605079
continue
50615080
end

src/compiler/interpreter.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
136136
within_autodiff_rewrite::Bool
137137

138138
handler::T
139+
140+
context::Enzyme.Compiler.EnzymeContext
139141
end
140142

141143
const SigCache = Dict{Tuple, Dict{UInt, Base.IdSet{Type}}}()
@@ -247,7 +249,8 @@ function EnzymeInterpreter(
247249
inactive_rules::Bool,
248250
broadcast_rewrite::Bool,
249251
within_autodiff_rewrite::Bool,
250-
handler
252+
handler,
253+
Enzyme.Compiler.EnzymeContext(world)
251254
)
252255
end
253256

@@ -278,7 +281,9 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
278281
inactive_rules = interp.inactive_rules,
279282
broadcast_rewrite = interp.broadcast_rewrite,
280283
within_autodiff_rewrite = interp.within_autodiff_rewrite,
281-
handler = interp.handler)
284+
handler = interp.handler,
285+
context = interp.context,)
286+
@assert context.world == world
282287
return EnzymeInterpreter(
283288
cache_or_token,
284289
mt,
@@ -291,7 +296,8 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
291296
inactive_rules,
292297
broadcast_rewrite,
293298
within_autodiff_rewrite,
294-
handler
299+
handler,
300+
context
295301
)
296302
end
297303

0 commit comments

Comments
 (0)