33using Core. Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber, ReturnNode
44using Base: _methods_by_ftype
55
6- # generated function that crafts a custom code info to call the actual compiler.
7- # this gives us the flexibility to insert manual back edges for automatic recompilation.
6+ # generated function that returns the world age of a compilation job. this can be used to
7+ # drive compilation, e.g. by using it as a key for a cache, as the age will change when a
8+ # function or any called function is redefined.
9+
10+
11+ """
12+ get_world(ft, tt)
13+
14+ A special function that returns the world age in which the current definition of function
15+ type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
16+ compilation results:
17+
18+ compilation_cache = Dict()
19+ function cache_compilation(ft, tt)
20+ world = get_world(ft, tt)
21+ get!(compilation_cache, (ft, tt, world)) do
22+ # compile
23+ end
24+ end
25+
26+ What makes this function special is that it is a generated function, returning a constant,
27+ whose result is automatically invalidated when the function `ft` (or any called function) is
28+ redefined. This makes this query ideally suited for hot code, where you want to avoid a
29+ costly look-up of the current world age on every invocation.
30+
31+ Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.
32+
33+ !!! warning
34+
35+ Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
36+ guaranteed to be correctly invalidated when the target function `ft` is executed or
37+ processed by codegen (e.g., by calling `code_llvm`).
38+ """
39+ get_world
40+
41+ # generate functions currently do not know which world they are invoked for, so we fall
42+ # back to using the current world. this may be wrong when the generator is invoked in a
43+ # different world (TODO : when does this happen?)
844#
9- # we also increment a global specialization counter and pass it along to index the cache.
10-
11- const specialization_counter = Ref{UInt}(0 )
12- @generated function specialization_id(job:: CompilerJob{<:Any,<:Any,FunctionSpec{f,tt}} ) where {f,tt}
13- # get a hold of the method and code info of the kernel function
14- sig = Tuple{f, tt. parameters... }
15- # XXX : instead of typemax(UInt) we should use the world-age of the fspec
16- mthds = _methods_by_ftype(sig, - 1 , typemax(UInt))
45+ # XXX : this should be fixed by JuliaLang/julia#48611
46+
47+ function get_world_generator(self, :: Type{Type{ft}} , :: Type{Type{tt}} ) where {ft, tt}
48+ @nospecialize
49+
50+ # look up the method
51+ sig = Tuple{ft, tt. parameters... }
52+ min_world = Ref{UInt}(typemin(UInt))
53+ max_world = Ref{UInt}(typemax(UInt))
54+ has_ambig = Ptr{Int32}(C_NULL ) # don't care about ambiguous results
55+ mthds = if VERSION >= v" 1.7.0-DEV.1297"
56+ Base. _methods_by_ftype(sig, #= mt=# nothing , #= lim=# - 1 ,
57+ #= world=# typemax(UInt), #= ambig=# false ,
58+ min_world, max_world, has_ambig)
59+ # XXX : use the correct method table to support overlaying kernels
60+ else
61+ Base. _methods_by_ftype(sig, #= lim=# - 1 ,
62+ #= world=# typemax(UInt), #= ambig=# false ,
63+ min_world, max_world, has_ambig)
64+ end
65+ # XXX : using world=-1 is wrong, but the current world isn't exposed to this generator
66+
67+ # check the validity of the method matches
68+ method_error = :(throw(MethodError(ft, tt)))
69+ mthds === nothing && return method_error
1770 Base. isdispatchtuple(tt) || return (:(error(" $tt is not a dispatch tuple" )))
18- length(mthds) == 1 || return (:(throw(MethodError(job. source. f,job. source. tt))))
71+ length(mthds) == 1 || return method_error
72+
73+ # look up the method and code instance
1974 mtypes, msp, m = mthds[1 ]
2075 mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
2176 ci = retrieve_code_info(mi):: CodeInfo
2277
23- # generate a unique id to represent this specialization
24- # TODO : just use the lower world age bound in which this code info is valid.
25- # (the method instance doesn't change when called functions are changed).
26- # but how to get that? the ci here always has min/max world 1/-1.
27- # XXX : don't use `objectid(ci)` here, apparently it can alias (or the CI doesn't change?)
28- id = (specialization_counter[] += 1 )
78+ # XXX : we don't know the world age that this generator was requested to run in, so use
79+ # the current world (we cannot use the mi's world because that doesn't update when
80+ # called functions are changed). this isn't correct, but should be close.
81+ world = Base. get_world_counter()
2982
3083 # prepare a new code info
3184 new_ci = copy(ci)
@@ -34,22 +87,20 @@ const specialization_counter = Ref{UInt}(0)
3487 resize!(new_ci. linetable, 1 ) # see note below
3588 empty!(new_ci. ssaflags)
3689 new_ci. ssavaluetypes = 0
90+ new_ci. min_world = min_world[]
91+ new_ci. max_world = max_world[]
3792 new_ci. edges = MethodInstance[mi]
3893 # XXX : setting this edge does not give us proper method invalidation, see
3994 # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
4095 # invoking `code_llvm` also does the necessary codegen, as does calling the
4196 # underlying C methods -- which GPUCompiler does, so everything Just Works.
4297
4398 # prepare the slots
44- new_ci. slotnames = Symbol[Symbol(" #self#" ), :cache, :job, :compiler, :linker]
45- new_ci. slotflags = UInt8[0x00 for i = 1 : 5 ]
46- cache = SlotNumber(2 )
47- job = SlotNumber(3 )
48- compiler = SlotNumber(4 )
49- linker = SlotNumber(5 )
50-
51- # call the compiler
52- push!(new_ci. code, ReturnNode(id))
99+ new_ci. slotnames = Symbol[Symbol(" #self#" ), :ft, :tt]
100+ new_ci. slotflags = UInt8[0x00 for i = 1 : 3 ]
101+
102+ # return the world
103+ push!(new_ci. code, ReturnNode(world))
53104 push!(new_ci. ssaflags, 0x00 ) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
54105 push!(new_ci. codelocs, 1 ) # see note below
55106 new_ci. ssavaluetypes += 1
@@ -62,17 +113,48 @@ const specialization_counter = Ref{UInt}(0)
62113 return new_ci
63114end
64115
116+ @eval function get_world(ft, tt)
117+ $ (Expr(:meta, :generated_only))
118+ $ (Expr(:meta,
119+ :generated,
120+ Expr(:new,
121+ Core. GeneratedFunctionStub,
122+ :get_world_generator,
123+ Any[:get_world, :ft, :tt],
124+ Any[],
125+ @__LINE__,
126+ QuoteNode(Symbol(@__FILE__)),
127+ true )))
128+ end
129+
65130const cache_lock = ReentrantLock()
131+
132+ """
133+ cached_compilation(cache::Dict, job::CompilerJob, compiler, linker)
134+
135+ Compile `job` using `compiler` and `linker`, and store the result in `cache`.
136+
137+ The `cache` argument should be a dictionary that can be indexed using a `UInt` and store
138+ whatever the `linker` function returns. The `compiler` function should take a `CompilerJob`
139+ and return data that can be cached across sessions (e.g., LLVM IR). This data is then
140+ forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create
141+ session-dependent objects (e.g., a `CuModule`).
142+ """
66143function cached_compilation(cache:: AbstractDict ,
67144 @nospecialize(job:: CompilerJob ),
68145 compiler:: Function , linker:: Function )
69- # XXX : CompilerJob contains a world age, so can't be respecialized.
70- # have specialization_id take a f/tt and return a world to construct a CompilerJob?
71- key = hash(job, specialization_id(job))
72- force_compilation = compile_hook[] != = nothing
146+ # NOTE: it is OK to index the compilation cache directly with the compilation job, i.e.,
147+ # using a world age instead of intersecting world age ranges, because we expect
148+ # that the world age is aquired through calling `get_world` and thus will only
149+ # ever change when the kernel function is redefined.
150+ #
151+ # if we ever want to be able to index the cache using a compilation job that
152+ # contains a more recent world age, yet still return an older cached object that
153+ # would still be valid, we'd need the cache to store world ranges instead and
154+ # use an invalidation callback to add upper bounds to entries.
155+ key = hash(job)
73156
74- # XXX : by taking the hash, we index the compilation cache directly with the world age.
75- # that's wrong; we should perform an intersection with the entry its bounds.
157+ force_compilation = compile_hook[] != = nothing
76158
77159 # NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
78160 lock(cache_lock)
0 commit comments