Skip to content

Commit 0269f1d

Browse files
committed
allways use get_interpreter to construct things correctly
1 parent cce7cfe commit 0269f1d

File tree

4 files changed

+14
-93
lines changed

4 files changed

+14
-93
lines changed

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6397,7 +6397,7 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode
63976397

63986398
target = EnzymeTarget()
63996399
rt2 = if A isa UnionAll
6400-
rrt = primal_return_type_world(Mode == API.DEM_ForwardMode ? Forward : Reverse, world, mi)
6400+
rrt = primal_return_type_world(Mode, world, mi)
64016401

64026402
# Don't error here but default to nothing return since in cuda context we don't use the device overrides
64036403
if rrt == Union{}

src/compiler/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function get_job(
2828
end
2929

3030
primal = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt, world)
31-
rt = Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt)
31+
rt = Compiler.primal_return_type_world(mode, world, primal)
3232

3333
@assert primal !== nothing
3434
rt = A{rt}

src/errors.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,27 +49,10 @@ end
4949
using InteractiveUtils
5050

5151
function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool=false, kwargs...)
52-
CT = @static if VERSION >= v"1.11.0-DEV.1552"
53-
EnzymeCacheToken(
54-
typeof(DefaultCompilerTarget()),
55-
false,
56-
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
57-
EnzymeCompilerParams,
58-
world,
59-
mode == API.DEM_ForwardMode,
60-
mode != API.DEM_ForwardMode,
61-
true
62-
)
63-
else
64-
if mode == API.DEM_ForwardMode
65-
GLOBAL_FWD_CACHE
66-
else
67-
GLOBAL_REV_CACHE
68-
end
69-
end
70-
71-
interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
72-
52+
target = Compiler.DefaultCompilerTarget()
53+
params = PrimalCompilerParams(mode)
54+
job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params))
55+
interp = GPUCompiler.get_interpreter(job)
7356
sig = mi.specTypes # XXX: can we just use the method instance?
7457
if interactive
7558
# call Cthulhu without introducing a dependency on Cthulhu

src/typeutils/inference.jl

Lines changed: 8 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,82 +13,20 @@ function return_type(interp::Core.Compiler.AbstractInterpreter, mi::Core.MethodI
1313
end
1414
end
1515

16-
function primal_interp_world(
17-
@nospecialize(::ReverseMode),
18-
world::UInt
19-
)
20-
mode = Enzyme.API.DEM_ReverseModeCombined
21-
22-
CT = @static if VERSION >= v"1.11.0-DEV.1552"
23-
EnzymeCacheToken(
24-
typeof(DefaultCompilerTarget()),
25-
false,
26-
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
27-
EnzymeCompilerParams,
28-
world,
29-
false,
30-
true,
31-
true
32-
)
33-
else
34-
Enzyme.Compiler.GLOBAL_REV_CACHE
35-
end
36-
37-
Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
16+
function primal_interp_world(mode::Enzyme.API.CDerivativeMode, world, mi)
17+
target = Compiler.DefaultCompilerTarget()
18+
params = PrimalCompilerParams(mode)
19+
job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params), world)
20+
return GPUCompiler.get_interpreter(job)
3821
end
3922

40-
function primal_interp_world(
41-
@nospecialize(::ForwardMode),
42-
world::UInt
43-
)
44-
mode = Enzyme.API.DEM_ForwardMode
45-
46-
CT = @static if VERSION >= v"1.11.0-DEV.1552"
47-
EnzymeCacheToken(
48-
typeof(DefaultCompilerTarget()),
49-
false,
50-
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
51-
EnzymeCompilerParams,
52-
world,
53-
true,
54-
false,
55-
true
56-
)
57-
else
58-
Enzyme.Compiler.GLOBAL_FWD_CACHE
59-
end
23+
primal_interp_world(mode::Mode, world, mi) = primal_interp_world(convert(Enzyme.API.CDerivativeMode, mode), world, mi)
6024

61-
Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
62-
end
63-
64-
@inline primal_interp_world(
65-
@nospecialize(::ReverseModeSplit),
66-
world::UInt) = primal_interp_world(Reverse, world)
67-
68-
function primal_return_type_world(
69-
@nospecialize(mode::Mode),
70-
world::UInt,
71-
@nospecialize(TT::Type),
72-
)
73-
Core.Compiler._return_type(primal_interp_world(mode, world), TT)
74-
end
75-
76-
function primal_return_type_world(
77-
@nospecialize(mode::Mode),
78-
world::UInt,
79-
mi::Core.MethodInstance,
80-
)
81-
interp = primal_interp_world(mode, world)
25+
function primal_return_type_world(mode, world, mi)
26+
interp = primal_interp_world(mode, world, mi)
8227
return_type(interp, mi)
8328
end
8429

85-
primal_return_type_world(
86-
@nospecialize(mode::Mode),
87-
world::UInt,
88-
@nospecialize(FT::Type),
89-
@nospecialize(TT::Type),
90-
) = primal_return_type_world(mode, world, Tuple{FT, TT.parameters...})
91-
9230
function primal_return_type end
9331

9432
function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type))

0 commit comments

Comments
 (0)