Skip to content

Commit c7d3884

Browse files
authored
Refactor generated functions to use an empty CodeInfo. (#2663)
1 parent 0d9ed77 commit c7d3884

File tree

5 files changed

+157
-300
lines changed

5 files changed

+157
-300
lines changed

src/analyses/activity.jl

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -451,48 +451,20 @@ Base.@nospecializeinfer @inline function active_reg(@nospecialize(ST::Type), wor
451451
return result
452452
end
453453

454+
function active_reg_nothrow end
455+
454456
function active_reg_nothrow_generator(world::UInt, source::Union{Method, LineNumberNode}, T, self, _)
455457
@nospecialize
456458
result = active_reg(T, world)
457-
458-
# create an empty CodeInfo to return the result
459-
ci = ccall(:jl_new_code_info_uninit, Ref{Core.CodeInfo}, ())
460-
461-
@static if isdefined(Core, :DebugInfo)
462-
# TODO: Add proper debug info
463-
ci.debuginfo = Core.DebugInfo(:none)
464-
else
465-
ci.codelocs = Int32[]
466-
ci.linetable = [
467-
Core.Compiler.LineInfoNode(@__MODULE__, :active_reg_nothrow, source.file, Int32(source.line), Int32(0))
468-
]
469-
end
470459
check_activity_cache_invalidations(world)
471-
ci.min_world = world
472-
ci.max_world = typemax(UInt)
473460

474-
edges = Any[]
475-
# Create the edge for the "query"
476-
# TODO: Check if we can use `Tuple{typeof(EnzymeRules.inactive_type), T}` directly
461+
slotnames = Core.svec(Symbol("#self#"), :T)
462+
code = Any[Core.Compiler.ReturnNode(result)]
463+
ci = create_fresh_codeinfo(active_reg_nothrow, source, world, slotnames, code)
464+
465+
ci.edges = Any[]
477466
inactive_type_sig = Tuple{typeof(EnzymeRules.inactive_type), Type}
478-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), inactive_type_sig)::Core.MethodTable)
479-
push!(edges, inactive_type_sig)
480-
481-
ci.edges = edges
482-
483-
# prepare the slots
484-
ci.slotnames = Symbol[Symbol("#self#"), :t]
485-
ci.slotflags = UInt8[0x00 for i = 1:2]
486-
487-
# return the result
488-
ci.code = Any[Core.Compiler.ReturnNode(result)]
489-
ci.ssaflags = UInt32[0x00] # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
490-
@static if isdefined(Core, :DebugInfo)
491-
else
492-
push!(ci.codelocs, 1)
493-
end
494-
495-
ci.ssavaluetypes = 1
467+
add_edge!(ci.edges, inactive_type_sig)
496468

497469
return ci
498470
end

src/compiler.jl

Lines changed: 52 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ import Enzyme:
2727
FnTypeInfo,
2828
Logic,
2929
allocatedinline,
30-
ismutabletype
30+
ismutabletype,
31+
create_fresh_codeinfo,
32+
add_edge!
3133
using Enzyme
3234

3335
import EnzymeCore
@@ -6227,88 +6229,37 @@ end
62276229
end
62286230
end
62296231

6232+
function thunk end
6233+
62306234
function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fakeworld), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(abi::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type))
62316235
@nospecialize
62326236

6233-
parmnames = (:fakeworld, :fa, :a, :tt, :mode, :width, :modifiedbetween, :returnprimal, :shadowinit, :abi, :erriffuncwritten, :runtimeactivity, :strongzero)
6234-
stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, parmnames...), Core.svec())
6237+
slotnames = Core.svec(Symbol("#self#"),
6238+
:fakeworld, :fa, :a, :tt, :mode, :width,
6239+
:modifiedbetween, :returnprimal, :shadowinit,
6240+
:abi, :erriffuncwritten, :runtimeactivity, :strongzero)
6241+
stub = Core.GeneratedFunctionStub(thunk, slotnames, Core.svec())
62356242

62366243
ft = eltype(FA)
62376244
primal_tt = Tuple{map(eltype, TT.parameters)...}
62386245
# look up the method match
6239-
method_error = :(throw(MethodError($ft, $primal_tt, $world)))
62406246

62416247
min_world = Ref{UInt}(typemin(UInt))
62426248
max_world = Ref{UInt}(typemax(UInt))
62436249

62446250
mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world)
62456251

6246-
mi === nothing && return stub(world, source, method_error)
6252+
mi === nothing && return stub(world, source, :(throw(MethodError($ft, $primal_tt, $world))))
62476253

62486254
check_activity_cache_invalidations(world)
62496255

6250-
min_world2 = Ref{UInt}(typemin(UInt))
6251-
max_world2 = Ref{UInt}(typemax(UInt))
6252-
6253-
mi2 = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(Base.identity), Tuple{Nothing}, world, min_world2, max_world2)
6254-
6255-
ci = Core.Compiler.retrieve_code_info(mi2, world)::Core.Compiler.CodeInfo
6256-
6257-
# prepare a new code info
6258-
new_ci = copy(ci)
6259-
empty!(new_ci.code)
6260-
@static if isdefined(Core, :DebugInfo)
6261-
new_ci.debuginfo = Core.DebugInfo(:none)
6262-
else
6263-
empty!(new_ci.codelocs)
6264-
resize!(new_ci.linetable, 1) # see note below
6265-
end
6266-
empty!(new_ci.ssaflags)
6267-
new_ci.ssavaluetypes = 0
6268-
# new_ci.min_world = min_world[]
6269-
new_ci.min_world = world
6270-
new_ci.max_world = max_world[]
6271-
6272-
edges = Any[mi]
6273-
6274-
if Mode == API.DEM_ForwardMode
6275-
fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
6276-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable)
6277-
push!(edges, fwd_sig)
6278-
else
6279-
rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
6280-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable)
6281-
push!(edges, rev_sig)
6282-
6283-
rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}}
6284-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable)
6285-
push!(edges, rev_sig)
6286-
end
6287-
6288-
ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}
6289-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable)
6290-
push!(edges, ina_sig)
6291-
6292-
for gen_sig in (
6293-
Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}},
6294-
Tuple{typeof(EnzymeRules.noalias), Vararg{Any}},
6295-
Tuple{typeof(EnzymeRules.inactive_type), Type},
6296-
)
6297-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable)
6298-
push!(edges, gen_sig)
6299-
end
6300-
6301-
new_ci.edges = edges
6302-
6303-
# XXX: setting this edge does not give us proper method invalidation, see
6304-
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
6305-
# invoking `code_llvm` also does the necessary codegen, as does calling the
6306-
# underlying C methods -- which GPUCompiler does, so everything Just Works.
6256+
edges = Any[]
6257+
add_edge!(edges, mi)
63076258

63086259
ts_ctx = JuliaContext()
63096260
ctx = context(ts_ctx)
63106261
activate(ctx)
6311-
res = try
6262+
result = try
63126263
thunkbase(
63136264
mi,
63146265
world,
@@ -6331,25 +6282,35 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no
63316282
dispose(ts_ctx)
63326283
end
63336284

6334-
# prepare the slots
6335-
new_ci.slotnames = Symbol[Symbol("#self#"), parmnames...]
6336-
new_ci.slotflags = UInt8[0x00 for i = 1:length(new_ci.slotnames)]
6285+
code = Any[Core.Compiler.ReturnNode(result)]
6286+
ci = create_fresh_codeinfo(thunk, source, world, slotnames, code)
63376287

6338-
# return the codegen world age
6339-
push!(new_ci.code, Core.Compiler.ReturnNode(res))
6340-
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
6341-
@static if isdefined(Core, :DebugInfo)
6288+
6289+
6290+
if Mode == API.DEM_ForwardMode
6291+
fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
6292+
add_edge!(edges, fwd_sig)
63426293
else
6343-
push!(new_ci.codelocs, 1) # see note below
6294+
rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
6295+
add_edge!(edges, rev_sig)
6296+
6297+
rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}}
6298+
add_edge!(edges, rev_sig)
6299+
end
6300+
6301+
ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}
6302+
add_edge!(edges, ina_sig)
6303+
6304+
for gen_sig in (
6305+
Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}},
6306+
Tuple{typeof(EnzymeRules.noalias), Vararg{Any}},
6307+
Tuple{typeof(EnzymeRules.inactive_type), Type},
6308+
)
6309+
add_edge!(edges, gen_sig)
63446310
end
6345-
new_ci.ssavaluetypes += 1
6346-
6347-
# NOTE: we keep the first entry of the original linetable, and use it for location info
6348-
# on the call to check_cache. we can't not have a codeloc (using 0 causes
6349-
# corruption of the back trace), and reusing the target function's info
6350-
# has as advantage that we see the name of the kernel in the backtraces.
63516311

6352-
return new_ci
6312+
ci.edges = edges
6313+
return ci
63536314
end
63546315

63556316
@eval @inline function thunk(
@@ -6386,48 +6347,30 @@ end
63866347

63876348
import GPUCompiler: deferred_codegen_jobs
63886349

6350+
function deferred_id_codegen end
6351+
63896352
function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode}, @nospecialize(FA::Type), @nospecialize(A::Type), @nospecialize(TT::Type), Mode::Enzyme.API.CDerivativeMode, Width::Int, @nospecialize(ModifiedBetween::(NTuple{N, Bool} where N)), ReturnPrimal::Bool, ShadowInit::Bool, @nospecialize(ExpectedTapeType::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, @nospecialize(self), @nospecialize(fa::Type), @nospecialize(a::Type), @nospecialize(tt::Type), @nospecialize(mode::Type), @nospecialize(width::Type), @nospecialize(modifiedbetween::Type), @nospecialize(returnprimal::Type), @nospecialize(shadowinit::Type), @nospecialize(expectedtapetype::Type), @nospecialize(erriffuncwritten::Type), @nospecialize(runtimeactivity::Type), @nospecialize(strongzero::Type))
63906353
@nospecialize
63916354

6392-
parmnames = (:fa, :a, :tt, :mode, :width, :modifiedbetween, :returnprimal, :shadowinit, :expectedtapetype, :erriffuncwritten, :runtimeactivity, :strongzero)
6393-
stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, parmnames...), Core.svec())
6355+
slotnames = Core.svec(Symbol("#self#"),
6356+
:fa, :a, :tt, :mode, :width, :modifiedbetween,
6357+
:returnprimal, :shadowinit, :expectedtapetype,
6358+
:erriffuncwritten, :runtimeactivity, :strongzero)
6359+
6360+
stub = Core.GeneratedFunctionStub(deferred_id_generator, slotnames, Core.svec())
63946361

63956362
ft = eltype(FA)
63966363
primal_tt = Tuple{map(eltype, TT.parameters)...}
63976364
# look up the method match
6398-
method_error = :(throw(MethodError($ft, $primal_tt, $world)))
63996365

64006366
min_world = Ref{UInt}(typemin(UInt))
64016367
max_world = Ref{UInt}(typemax(UInt))
64026368

64036369
mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world)
64046370

6405-
mi === nothing && return stub(world, source, method_error)
6371+
mi === nothing && return stub(world, source, :(throw(MethodError($ft, $primal_tt, $world))))
64066372

6407-
ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo
6408-
6409-
# prepare a new code info
6410-
new_ci = copy(ci)
6411-
empty!(new_ci.code)
6412-
@static if isdefined(Core, :DebugInfo)
6413-
new_ci.debuginfo = Core.DebugInfo(:none)
6414-
else
6415-
empty!(new_ci.codelocs)
6416-
resize!(new_ci.linetable, 1) # see note below
6417-
end
6418-
empty!(new_ci.ssaflags)
6419-
new_ci.ssavaluetypes = 0
6420-
# new_ci.min_world = min_world[]
6421-
new_ci.min_world = world
6422-
new_ci.max_world = max_world[]
6423-
new_ci.edges = Core.MethodInstance[mi]
6424-
# XXX: setting this edge does not give us proper method invalidation, see
6425-
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
6426-
# invoking `code_llvm` also does the necessary codegen, as does calling the
6427-
# underlying C methods -- which GPUCompiler does, so everything Just Works.
6428-
64296373
target = EnzymeTarget()
6430-
64316374
rt2 = if A isa UnionAll
64326375
rrt = primal_return_type_world(Mode == API.DEM_ForwardMode ? Forward : Reverse, world, mi)
64336376

@@ -6472,25 +6415,12 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode
64726415
id = Base.reinterpret(Int, pointer(addr))
64736416
deferred_codegen_jobs[id] = job
64746417

6475-
# prepare the slots
6476-
new_ci.slotnames = Symbol[Symbol("#self#"), parmnames...]
6477-
new_ci.slotflags = UInt8[0x00 for i = 1:length(new_ci.slotnames)]
6478-
6479-
# return the codegen world age
6480-
push!(new_ci.code, Core.Compiler.ReturnNode(reinterpret(Ptr{Cvoid}, id)))
6481-
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
6482-
@static if isdefined(Core, :DebugInfo)
6483-
else
6484-
push!(new_ci.codelocs, 1) # see note below
6485-
end
6486-
new_ci.ssavaluetypes += 1
6418+
code = Any[Core.Compiler.ReturnNode(reinterpret(Ptr{Cvoid}, id))]
6419+
ci = create_fresh_codeinfo(deferred_id_codegen, source, world, slotnames, code)
64876420

6488-
# NOTE: we keep the first entry of the original linetable, and use it for location info
6489-
# on the call to check_cache. we can't not have a codeloc (using 0 causes
6490-
# corruption of the back trace), and reusing the target function's info
6491-
# has as advantage that we see the name of the kernel in the backtraces.
6421+
ci.edges = Any[mi]
64926422

6493-
return new_ci
6423+
return ci
64946424
end
64956425

64966426
@eval @inline function deferred_id_codegen(

src/compiler/interpreter.jl

Lines changed: 7 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,85 +24,28 @@ else
2424
import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
2525
end
2626

27+
function rule_backedge_holder end
28+
2729
function rule_backedge_holder_generator(world::UInt, source, self, ft::Type)
2830
@nospecialize
29-
sig = Tuple{typeof(Base.identity), Int}
30-
min_world = Ref{UInt}(typemin(UInt))
31-
max_world = Ref{UInt}(typemax(UInt))
32-
has_ambig = Ptr{Int32}(C_NULL)
33-
mthds = Base._methods_by_ftype(
34-
sig,
35-
nothing,
36-
-1, #=lim=#
37-
world,
38-
false, #=ambig=#
39-
min_world,
40-
max_world,
41-
has_ambig,
42-
)
43-
mtypes, msp, m = mthds[1]
44-
mi = ccall(
45-
:jl_specializations_get_linfo,
46-
Ref{Core.MethodInstance},
47-
(Any, Any, Any),
48-
m,
49-
mtypes,
50-
msp,
51-
)
52-
ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo
5331

54-
# prepare a new code info
55-
new_ci = copy(ci)
56-
empty!(new_ci.code)
57-
@static if isdefined(Core, :DebugInfo)
58-
new_ci.debuginfo = Core.DebugInfo(:none)
59-
else
60-
empty!(new_ci.codelocs)
61-
resize!(new_ci.linetable, 1) # see note below
62-
end
63-
empty!(new_ci.ssaflags)
64-
new_ci.ssavaluetypes = 0
65-
new_ci.min_world = min_world[]
66-
new_ci.max_world = max_world[]
32+
code = Any[Core.Compiler.ReturnNode(world)]
33+
ci = Core.Compiler.create_fresh_codeinfo(rule_backedge_holder, source, world, Core.svec(Symbol("#self#"), :ft), code)
6734

68-
### TODO: backedge from inactive, augmented_primal, forward, reverse
6935
edges = Any[]
7036

7137
if ft == typeof(EnzymeRules.augmented_primal)
7238
sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}
73-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig))
74-
push!(edges, sig)
7539
elseif ft == typeof(EnzymeRules.forward)
7640
sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}
77-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig))
78-
push!(edges, sig)
7941
else
8042
sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}}
81-
push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig))
82-
push!(edges, sig)
8343
end
44+
add_edge!(edges, sig)
8445

85-
new_ci.edges = edges
86-
87-
# XXX: setting this edge does not give us proper method invalidation, see
88-
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
89-
# invoking `code_llvm` also does the necessary codegen, as does calling the
90-
# underlying C methods -- which GPUCompiler does, so everything Just Works.
91-
92-
# prepare the slots
93-
new_ci.slotnames = Symbol[Symbol("#self#"), :ft]
94-
new_ci.slotflags = UInt8[0x00 for i = 1:2]
95-
96-
# return the codegen world age
97-
push!(new_ci.code, Core.Compiler.ReturnNode(world))
98-
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
99-
@static if isdefined(Core, :DebugInfo)
100-
else
101-
push!(new_ci.codelocs, 1) # see note below
102-
end
103-
new_ci.ssavaluetypes += 1
46+
ci.edges = edges
10447

105-
return new_ci
48+
return ci
10649
end
10750

10851
@eval Base.@assume_effects :removable :foldable :nothrow @inline function rule_backedge_holder(ft)

0 commit comments

Comments
 (0)