@@ -27,7 +27,9 @@ import Enzyme:
2727 FnTypeInfo,
2828 Logic,
2929 allocatedinline,
30- ismutabletype
30+ ismutabletype,
31+ create_fresh_codeinfo,
32+ add_edge!
3133using Enzyme
3234
3335import EnzymeCore
@@ -6227,88 +6229,37 @@ end
62276229 end
62286230end
62296231
6232+ function thunk end
6233+
62306234function thunk_generator (world:: UInt , source:: Union{Method, LineNumberNode} , @nospecialize (FA:: Type ), @nospecialize (A:: Type ), @nospecialize (TT:: Type ), Mode:: Enzyme.API.CDerivativeMode , Width:: Int , @nospecialize (ModifiedBetween:: (NTuple{N, Bool} where N) ), ReturnPrimal:: Bool , ShadowInit:: Bool , @nospecialize (ABI:: Type ), ErrIfFuncWritten:: Bool , RuntimeActivity:: Bool , StrongZero:: Bool , @nospecialize (self), @nospecialize (fakeworld), @nospecialize (fa:: Type ), @nospecialize (a:: Type ), @nospecialize (tt:: Type ), @nospecialize (mode:: Type ), @nospecialize (width:: Type ), @nospecialize (modifiedbetween:: Type ), @nospecialize (returnprimal:: Type ), @nospecialize (shadowinit:: Type ), @nospecialize (abi:: Type ), @nospecialize (erriffuncwritten:: Type ), @nospecialize (runtimeactivity:: Type ), @nospecialize (strongzero:: Type ))
62316235 @nospecialize
62326236
6233- parmnames = (:fakeworld , :fa , :a , :tt , :mode , :width , :modifiedbetween , :returnprimal , :shadowinit , :abi , :erriffuncwritten , :runtimeactivity , :strongzero )
6234- stub = Core. GeneratedFunctionStub (identity, Core. svec (:methodinstance , parmnames... ), Core. svec ())
6237+ slotnames = Core. svec (Symbol (" #self#" ),
6238+ :fakeworld , :fa , :a , :tt , :mode , :width ,
6239+ :modifiedbetween , :returnprimal , :shadowinit ,
6240+ :abi , :erriffuncwritten , :runtimeactivity , :strongzero )
6241+ stub = Core. GeneratedFunctionStub (thunk, slotnames, Core. svec ())
62356242
62366243 ft = eltype (FA)
62376244 primal_tt = Tuple{map (eltype, TT. parameters)... }
62386245 # look up the method match
6239- method_error = :(throw (MethodError ($ ft, $ primal_tt, $ world)))
62406246
62416247 min_world = Ref {UInt} (typemin (UInt))
62426248 max_world = Ref {UInt} (typemax (UInt))
62436249
62446250 mi = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world)
62456251
6246- mi === nothing && return stub (world, source, method_error )
6252+ mi === nothing && return stub (world, source, :( throw ( MethodError ( $ ft, $ primal_tt, $ world))) )
62476253
62486254 check_activity_cache_invalidations (world)
62496255
6250- min_world2 = Ref {UInt} (typemin (UInt))
6251- max_world2 = Ref {UInt} (typemax (UInt))
6252-
6253- mi2 = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (Base. identity), Tuple{Nothing}, world, min_world2, max_world2)
6254-
6255- ci = Core. Compiler. retrieve_code_info (mi2, world):: Core.Compiler.CodeInfo
6256-
6257- # prepare a new code info
6258- new_ci = copy (ci)
6259- empty! (new_ci. code)
6260- @static if isdefined (Core, :DebugInfo )
6261- new_ci. debuginfo = Core. DebugInfo (:none )
6262- else
6263- empty! (new_ci. codelocs)
6264- resize! (new_ci. linetable, 1 ) # see note below
6265- end
6266- empty! (new_ci. ssaflags)
6267- new_ci. ssavaluetypes = 0
6268- # new_ci.min_world = min_world[]
6269- new_ci. min_world = world
6270- new_ci. max_world = max_world[]
6271-
6272- edges = Any[mi]
6273-
6274- if Mode == API. DEM_ForwardMode
6275- fwd_sig = Tuple{typeof (EnzymeRules. forward), <: EnzymeRules.FwdConfig , <: Enzyme.EnzymeCore.Annotation , Type{<: Enzyme.EnzymeCore.Annotation },Vararg{Enzyme. EnzymeCore. Annotation}}
6276- push! (edges, ccall (:jl_method_table_for , Any, (Any,), fwd_sig):: Core.MethodTable )
6277- push! (edges, fwd_sig)
6278- else
6279- rev_sig = Tuple{typeof (EnzymeRules. augmented_primal), <: EnzymeRules.RevConfig , <: Enzyme.EnzymeCore.Annotation , Type{<: Enzyme.EnzymeCore.Annotation },Vararg{Enzyme. EnzymeCore. Annotation}}
6280- push! (edges, ccall (:jl_method_table_for , Any, (Any,), rev_sig):: Core.MethodTable )
6281- push! (edges, rev_sig)
6282-
6283- rev_sig = Tuple{typeof (EnzymeRules. reverse), <: EnzymeRules.RevConfig , <: Enzyme.EnzymeCore.Annotation , Union{Type{<: Enzyme.EnzymeCore.Annotation }, Enzyme. EnzymeCore. Active}, Any, Vararg{Enzyme. EnzymeCore. Annotation}}
6284- push! (edges, ccall (:jl_method_table_for , Any, (Any,), rev_sig):: Core.MethodTable )
6285- push! (edges, rev_sig)
6286- end
6287-
6288- ina_sig = Tuple{typeof (EnzymeRules. inactive), Vararg{Any}}
6289- push! (edges, ccall (:jl_method_table_for , Any, (Any,), ina_sig):: Core.MethodTable )
6290- push! (edges, ina_sig)
6291-
6292- for gen_sig in (
6293- Tuple{typeof (EnzymeRules. inactive_noinl), Vararg{Any}},
6294- Tuple{typeof (EnzymeRules. noalias), Vararg{Any}},
6295- Tuple{typeof (EnzymeRules. inactive_type), Type},
6296- )
6297- push! (edges, ccall (:jl_method_table_for , Any, (Any,), gen_sig):: Core.MethodTable )
6298- push! (edges, gen_sig)
6299- end
6300-
6301- new_ci. edges = edges
6302-
6303- # XXX : setting this edge does not give us proper method invalidation, see
6304- # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
6305- # invoking `code_llvm` also does the necessary codegen, as does calling the
6306- # underlying C methods -- which GPUCompiler does, so everything Just Works.
6256+ edges = Any[]
6257+ add_edge! (edges, mi)
63076258
63086259 ts_ctx = JuliaContext ()
63096260 ctx = context (ts_ctx)
63106261 activate (ctx)
6311- res = try
6262+ result = try
63126263 thunkbase (
63136264 mi,
63146265 world,
@@ -6331,25 +6282,35 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no
63316282 dispose (ts_ctx)
63326283 end
63336284
6334- # prepare the slots
6335- new_ci. slotnames = Symbol[Symbol (" #self#" ), parmnames... ]
6336- new_ci. slotflags = UInt8[0x00 for i = 1 : length (new_ci. slotnames)]
6285+ code = Any[Core. Compiler. ReturnNode (result)]
6286+ ci = create_fresh_codeinfo (thunk, source, world, slotnames, code)
63376287
6338- # return the codegen world age
6339- push! (new_ci. code, Core. Compiler. ReturnNode (res))
6340- push! (new_ci. ssaflags, 0x00 ) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
6341- @static if isdefined (Core, :DebugInfo )
6288+
6289+
6290+ if Mode == API. DEM_ForwardMode
6291+ fwd_sig = Tuple{typeof (EnzymeRules. forward), <: EnzymeRules.FwdConfig , <: Enzyme.EnzymeCore.Annotation , Type{<: Enzyme.EnzymeCore.Annotation },Vararg{Enzyme. EnzymeCore. Annotation}}
6292+ add_edge! (edges, fwd_sig)
63426293 else
6343- push! (new_ci. codelocs, 1 ) # see note below
6294+ rev_sig = Tuple{typeof (EnzymeRules. augmented_primal), <: EnzymeRules.RevConfig , <: Enzyme.EnzymeCore.Annotation , Type{<: Enzyme.EnzymeCore.Annotation },Vararg{Enzyme. EnzymeCore. Annotation}}
6295+ add_edge! (edges, rev_sig)
6296+
6297+ rev_sig = Tuple{typeof (EnzymeRules. reverse), <: EnzymeRules.RevConfig , <: Enzyme.EnzymeCore.Annotation , Union{Type{<: Enzyme.EnzymeCore.Annotation }, Enzyme. EnzymeCore. Active}, Any, Vararg{Enzyme. EnzymeCore. Annotation}}
6298+ add_edge! (edges, rev_sig)
6299+ end
6300+
6301+ ina_sig = Tuple{typeof (EnzymeRules. inactive), Vararg{Any}}
6302+ add_edge! (edges, ina_sig)
6303+
6304+ for gen_sig in (
6305+ Tuple{typeof (EnzymeRules. inactive_noinl), Vararg{Any}},
6306+ Tuple{typeof (EnzymeRules. noalias), Vararg{Any}},
6307+ Tuple{typeof (EnzymeRules. inactive_type), Type},
6308+ )
6309+ add_edge! (edges, gen_sig)
63446310 end
6345- new_ci. ssavaluetypes += 1
6346-
6347- # NOTE: we keep the first entry of the original linetable, and use it for location info
6348- # on the call to check_cache. we can't not have a codeloc (using 0 causes
6349- # corruption of the back trace), and reusing the target function's info
6350- # has as advantage that we see the name of the kernel in the backtraces.
63516311
6352- return new_ci
6312+ ci. edges = edges
6313+ return ci
63536314end
63546315
63556316@eval @inline function thunk (
@@ -6386,48 +6347,30 @@ end
63866347
63876348import GPUCompiler: deferred_codegen_jobs
63886349
6350+ function deferred_id_codegen end
6351+
63896352function deferred_id_generator (world:: UInt , source:: Union{Method, LineNumberNode} , @nospecialize (FA:: Type ), @nospecialize (A:: Type ), @nospecialize (TT:: Type ), Mode:: Enzyme.API.CDerivativeMode , Width:: Int , @nospecialize (ModifiedBetween:: (NTuple{N, Bool} where N) ), ReturnPrimal:: Bool , ShadowInit:: Bool , @nospecialize (ExpectedTapeType:: Type ), ErrIfFuncWritten:: Bool , RuntimeActivity:: Bool , StrongZero:: Bool , @nospecialize (self), @nospecialize (fa:: Type ), @nospecialize (a:: Type ), @nospecialize (tt:: Type ), @nospecialize (mode:: Type ), @nospecialize (width:: Type ), @nospecialize (modifiedbetween:: Type ), @nospecialize (returnprimal:: Type ), @nospecialize (shadowinit:: Type ), @nospecialize (expectedtapetype:: Type ), @nospecialize (erriffuncwritten:: Type ), @nospecialize (runtimeactivity:: Type ), @nospecialize (strongzero:: Type ))
63906353 @nospecialize
63916354
6392- parmnames = (:fa , :a , :tt , :mode , :width , :modifiedbetween , :returnprimal , :shadowinit , :expectedtapetype , :erriffuncwritten , :runtimeactivity , :strongzero )
6393- stub = Core. GeneratedFunctionStub (identity, Core. svec (:methodinstance , parmnames... ), Core. svec ())
6355+ slotnames = Core. svec (Symbol (" #self#" ),
6356+ :fa , :a , :tt , :mode , :width , :modifiedbetween ,
6357+ :returnprimal , :shadowinit , :expectedtapetype ,
6358+ :erriffuncwritten , :runtimeactivity , :strongzero )
6359+
6360+ stub = Core. GeneratedFunctionStub (deferred_id_generator, slotnames, Core. svec ())
63946361
63956362 ft = eltype (FA)
63966363 primal_tt = Tuple{map (eltype, TT. parameters)... }
63976364 # look up the method match
6398- method_error = :(throw (MethodError ($ ft, $ primal_tt, $ world)))
63996365
64006366 min_world = Ref {UInt} (typemin (UInt))
64016367 max_world = Ref {UInt} (typemax (UInt))
64026368
64036369 mi = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world)
64046370
6405- mi === nothing && return stub (world, source, method_error )
6371+ mi === nothing && return stub (world, source, :( throw ( MethodError ( $ ft, $ primal_tt, $ world))) )
64066372
6407- ci = Core. Compiler. retrieve_code_info (mi, world):: Core.Compiler.CodeInfo
6408-
6409- # prepare a new code info
6410- new_ci = copy (ci)
6411- empty! (new_ci. code)
6412- @static if isdefined (Core, :DebugInfo )
6413- new_ci. debuginfo = Core. DebugInfo (:none )
6414- else
6415- empty! (new_ci. codelocs)
6416- resize! (new_ci. linetable, 1 ) # see note below
6417- end
6418- empty! (new_ci. ssaflags)
6419- new_ci. ssavaluetypes = 0
6420- # new_ci.min_world = min_world[]
6421- new_ci. min_world = world
6422- new_ci. max_world = max_world[]
6423- new_ci. edges = Core. MethodInstance[mi]
6424- # XXX : setting this edge does not give us proper method invalidation, see
6425- # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
6426- # invoking `code_llvm` also does the necessary codegen, as does calling the
6427- # underlying C methods -- which GPUCompiler does, so everything Just Works.
6428-
64296373 target = EnzymeTarget ()
6430-
64316374 rt2 = if A isa UnionAll
64326375 rrt = primal_return_type_world (Mode == API. DEM_ForwardMode ? Forward : Reverse, world, mi)
64336376
@@ -6472,25 +6415,12 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode
64726415 id = Base. reinterpret (Int, pointer (addr))
64736416 deferred_codegen_jobs[id] = job
64746417
6475- # prepare the slots
6476- new_ci. slotnames = Symbol[Symbol (" #self#" ), parmnames... ]
6477- new_ci. slotflags = UInt8[0x00 for i = 1 : length (new_ci. slotnames)]
6478-
6479- # return the codegen world age
6480- push! (new_ci. code, Core. Compiler. ReturnNode (reinterpret (Ptr{Cvoid}, id)))
6481- push! (new_ci. ssaflags, 0x00 ) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
6482- @static if isdefined (Core, :DebugInfo )
6483- else
6484- push! (new_ci. codelocs, 1 ) # see note below
6485- end
6486- new_ci. ssavaluetypes += 1
6418+ code = Any[Core. Compiler. ReturnNode (reinterpret (Ptr{Cvoid}, id))]
6419+ ci = create_fresh_codeinfo (deferred_id_codegen, source, world, slotnames, code)
64876420
6488- # NOTE: we keep the first entry of the original linetable, and use it for location info
6489- # on the call to check_cache. we can't not have a codeloc (using 0 causes
6490- # corruption of the back trace), and reusing the target function's info
6491- # has as advantage that we see the name of the kernel in the backtraces.
6421+ ci. edges = Any[mi]
64926422
6493- return new_ci
6423+ return ci
64946424end
64956425
64966426@eval @inline function deferred_id_codegen (
0 commit comments