Skip to content

Commit cce7cfe

Browse files
committed
Enzyme respect GPUCompiler.optimization_params
1 parent 986bbf4 commit cce7cfe

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

src/compiler.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ if VERSION >= v"1.11.0-DEV.1552"
185185
Interpreter.EnzymeInterpreter(
186186
GPUCompiler.ci_cache_token(job),
187187
GPUCompiler.method_table(job),
188+
GPUCompiler.inference_params(job),
189+
GPUCompiler.optimization_params(job),
188190
job.world,
189191
job.config.params.mode,
190192
true
@@ -211,6 +213,8 @@ else
211213
Interpreter.EnzymeInterpreter(
212214
enzyme_ci_cache(job),
213215
GPUCompiler.method_table(job),
216+
GPUCompiler.inference_params(job),
217+
GPUCompiler.optimization_params(job),
214218
job.world,
215219
job.config.params.mode,
216220
true

src/compiler/interpreter.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ function EnzymeInterpreter(
124124
cache_or_token,
125125
mt::Union{Nothing,Core.MethodTable},
126126
world::UInt,
127+
inf_params::InferenceParams,
128+
opt_params::OptimizationParams,
127129
forward_rules::Bool,
128130
reverse_rules::Bool,
129131
inactive_rules::Bool,
@@ -133,11 +135,11 @@ function EnzymeInterpreter(
133135
)
134136
@assert world <= Base.get_world_counter()
135137

136-
parms = @static if VERSION >= v"1.12.0-DEV.1017"
137-
InferenceParams()
138-
else
139-
InferenceParams(; unoptimize_throw_blocks=false)
140-
end
138+
# parms = @static if VERSION >= v"1.12.0-DEV.1017"
139+
# InferenceParams()
140+
# else
141+
# InferenceParams(; unoptimize_throw_blocks=false)
142+
# end
141143

142144
@static if HAS_INTEGRATED_CACHE
143145

@@ -171,10 +173,11 @@ function EnzymeInterpreter(
171173
Base.empty!(cache_or_token)
172174
end
173175
end
176+
method_table = mt == nothing ? Core.Compiler.InternalMethodTable(world) : Core.Compiler.OverlayMethodTable(world, mt),
174177

175178
return EnzymeInterpreter(
176179
cache_or_token,
177-
mt == nothing ? Core.Compiler.InternalMethodTable(world) : Core.Compiler.OverlayMethodTable(world, mt),
180+
method_table,
178181

179182
# Initially empty cache
180183
Vector{InferenceResult}(),
@@ -183,8 +186,8 @@ function EnzymeInterpreter(
183186
world,
184187

185188
# parameters for inference and optimization
186-
parms,
187-
OptimizationParams(),
189+
inf_params,
190+
opt_params,
188191
forward_rules::Bool,
189192
reverse_rules::Bool,
190193
inactive_rules::Bool,

0 commit comments

Comments
 (0)