@@ -42,6 +42,25 @@ import LLVM: Target, TargetMachine
4242import SparseArrays
4343using Printf
4444
45+ @enum ActivityState begin
46+ AnyState = 0
47+ ActiveState = 1
48+ DupState = 2
49+ MixedState = 3
50+ end
51+
52+ @inline function Base.:| (a1:: ActivityState , a2:: ActivityState )
53+ ActivityState (Int (a1) | Int (a2))
54+ end
55+
56+ mutable struct EnzymeContext
57+ world:: UInt
58+ activity_cache:: Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}
59+ function EnzymeContext (world)
60+ new (world, Dict {Tuple{Type,Bool,Bool,Bool},ActivityState} ())
61+ end
62+ end
63+
4564using Preferences
4665
4766bitcode_replacement () = parse (Bool, @load_preference (" bitcode_replacement" , " true" ))
@@ -3282,7 +3301,7 @@ function create_abi_wrapper(
32823301 # 3 is index of shadow
32833302 if existed[3 ] != 0 &&
32843303 sret_union &&
3285- active_reg (pactualRetType, world ; justActive= true , UnionSret= true ) == ActiveState
3304+ active_reg_cached (interp . context, pactualRetType ; justActive= true , UnionSret= true ) == ActiveState
32863305 rewrite_union_returns_as_ref (enzymefn, data[3 ], world, width)
32873306 end
32883307 returnNum = 0
@@ -4951,7 +4970,7 @@ end
49514970 if params. err_if_func_written
49524971 FT = TT. parameters[1 ]
49534972 Ty = eltype (FT)
4954- reg = active_reg (Ty, job . world )
4973+ reg = active_reg_cached (interp . context, Ty )
49554974 if reg == DupState || reg == MixedState
49564975 swiftself = has_swiftself (primalf)
49574976 todo = LLVM. Value[parameters (primalf)[1 + swiftself]]
@@ -4975,7 +4994,7 @@ end
49754994 if ! mayWriteToMemory (user)
49764995 slegal, foundv, byref = abs_typeof (user)
49774996 if slegal
4978- reg2 = active_reg (foundv, job . world )
4997+ reg2 = active_reg_cached (interp . context, foundv )
49794998 if reg2 == ActiveState || reg2 == AnyState
49804999 continue
49815000 end
@@ -5003,7 +5022,7 @@ end
50035022 if operands (user)[2 ] == cur
50045023 slegal, foundv, byref = abs_typeof (operands (user)[1 ])
50055024 if slegal
5006- reg2 = active_reg (foundv, job . world )
5025+ reg2 = active_reg_cached (interp . context, foundv )
50075026 if reg2 == AnyState
50085027 continue
50095028 end
@@ -5037,7 +5056,7 @@ end
50375056 if is_readonly (called)
50385057 slegal, foundv, byref = abs_typeof (user)
50395058 if slegal
5040- reg2 = active_reg (foundv, job . world )
5059+ reg2 = active_reg_cached (interp . context, foundv )
50415060 if reg2 == ActiveState || reg2 == AnyState
50425061 continue
50435062 end
@@ -5055,7 +5074,7 @@ end
50555074 end
50565075 slegal, foundv, byref = abs_typeof (user)
50575076 if slegal
5058- reg2 = active_reg (foundv, job . world )
5077+ reg2 = active_reg_cached (interp . context, foundv )
50595078 if reg2 == ActiveState || reg2 == AnyState
50605079 continue
50615080 end
0 commit comments