Skip to content

Commit 577787d

Browse files
committed
Interp: optionally disable inactive noinline
1 parent 4fc9eb8 commit 577787d

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

src/compiler/interpreter.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
129129

130130
forward_rules::Bool
131131
reverse_rules::Bool
132+
inactive_rules::Bool
132133
broadcast_rewrite::Bool
133134
handler::T
134135
end
@@ -166,6 +167,7 @@ function EnzymeInterpreter(
166167
world::UInt,
167168
forward_rules::Bool,
168169
reverse_rules::Bool,
170+
inactive_rules::Bool,
169171
broadcast_rewrite::Bool = true,
170172
handler = nothing
171173
)
@@ -197,10 +199,12 @@ function EnzymeInterpreter(
197199
end
198200
end
199201

200-
inarules = get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)
201-
if !rule_sigs_equal(inarules, LastInaWorld[])
202-
LastInaWorld[] = inarules
203-
invalid = true
202+
if inactive_rules
203+
inarules = get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)
204+
if !rule_sigs_equal(inarules, LastInaWorld[])
205+
LastInaWorld[] = inarules
206+
invalid = true
207+
end
204208
end
205209

206210
if invalid
@@ -223,6 +227,7 @@ function EnzymeInterpreter(
223227
OptimizationParams(),
224228
forward_rules,
225229
reverse_rules,
230+
inactive_rules,
226231
broadcast_rewrite,
227232
handler
228233
)
@@ -364,20 +369,12 @@ function Core.Compiler.abstract_call_gf_by_type(
364369
callinfo = AlwaysInlineCallInfo(callinfo, atype)
365370
else
366371
method_table = Core.Compiler.method_table(interp)
367-
if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
372+
if interp.inactive_rules && EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
368373
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
369-
else
370-
if interp.forward_rules
371-
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
372-
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
373-
end
374-
end
375-
376-
if interp.reverse_rules
377-
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
378-
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
379-
end
380-
end
374+
elseif interp.forward_rules && EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
375+
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
376+
elseif interp.reverse_rules && EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
377+
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
381378
end
382379
end
383380

src/typeutils/inference.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ function primal_interp_world(
2727
EnzymeCompilerParams,
2828
world,
2929
false,
30+
true,
3031
true
3132
)
3233
else
@@ -50,7 +51,8 @@ function primal_interp_world(
5051
EnzymeCompilerParams,
5152
world,
5253
true,
53-
false
54+
false,
55+
true
5456
)
5557
else
5658
Enzyme.Compiler.GLOBAL_FWD_CACHE

0 commit comments

Comments
 (0)