Opened on Sep 21, 2024
Successor to #1845
Julia avoids specializing methods on some arguments, most notably when an argument's type is `<: Function` and the function is not called in the method body but only passed through to an inner function. This does not block type inference, only code generation, and runtime dispatch is often avoided anyway by inlining, since such "pass-through" methods are usually small.
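A minimal sketch of the pass-through pattern and the type-parameter workaround, using the hypothetical names `inner`, `outer_nospec`, and `outer_spec`. Both outer methods return the same result; the difference is only in which method instances Julia compiles. Note that, per the Julia manual's performance tips, `@code_typed` and similar tools show specialized code even when Julia would not normally specialize, so the heuristic is not directly visible there.

```julia
# The inner method calls `f`, so Julia specializes it on `typeof(f)`.
inner(f, x) = f(x)

# Pass-through: `f` is only forwarded, never called here, so Julia's heuristic
# may compile one instance shared by all functions instead of one per `typeof(f)`.
outer_nospec(f, x) = inner(f, x)

# Workaround: a type parameter forces specialization on `typeof(f)`.
outer_spec(f::F, x) where {F} = inner(f, x)

outer_nospec(abs, -2.0)  # 2.0, possibly via a less specialized method instance
outer_spec(abs, -2.0)    # 2.0, via an instance specialized on `typeof(abs)`
```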
However, if a custom rule is written for such a method, Enzyme sees the call as type unstable and falls back to the runtime handler, whose activity analysis is limited. Besides the performance penalty, this throws an error whenever the activity analysis fails to prove that a keyword argument is `Const`. A prominent case is the custom rules for QuadGK.jl, since `quadgk` takes both a function argument and non-active floating-point keyword arguments (such as `atol` and `rtol`) to set tolerances.
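To make the shape of such an API concrete, here is a hypothetical stand-in with the same signature pattern as `quadgk`: a function argument plus a floating-point tolerance keyword that should never be treated as active. It is not QuadGK.jl's API or algorithm; it uses adaptive Simpson integration instead of Gauss-Kronrod, and the names `simpson`, `_refine`, and `integrate` are invented for illustration. The outer `integrate` wrapper plays the role of a pass-through method and carries the `f::F` annotation.

```julia
# Hypothetical `quadgk`-shaped API: function argument + tolerance keyword.
# Uses adaptive Simpson, NOT Gauss-Kronrod.
function simpson(f::F, a, b; atol = 1e-10, maxdepth = 40) where {F}
    m = (a + b) / 2
    fa, fm, fb = f(a), f(m), f(b)
    whole = (b - a) / 6 * (fa + 4fm + fb)
    return _refine(f, a, b, fa, fm, fb, whole, atol, maxdepth)
end

# Recursively split [a, b] until the two half-interval Simpson estimates
# agree with the whole-interval estimate within the tolerance.
function _refine(f::F, a, b, fa, fm, fb, whole, atol, depth) where {F}
    m = (a + b) / 2
    flm, frm = f((a + m) / 2), f((m + b) / 2)
    left = (m - a) / 6 * (fa + 4flm + fm)
    right = (b - m) / 6 * (fm + 4frm + fb)
    delta = left + right - whole
    if depth <= 0 || abs(delta) <= 15atol
        return left + right + delta / 15  # Richardson-style correction
    end
    return _refine(f, a, m, fa, flm, fm, left, atol / 2, depth - 1) +
           _refine(f, m, b, fm, frm, fb, right, atol / 2, depth - 1)
end

# Pass-through wrapper: `f` is only forwarded, so without `::F` Julia may
# not specialize this method on `typeof(f)`.
integrate(f::F, a, b; atol = 1e-10) where {F} = simpson(f, a, b; atol = atol)

integrate(x -> x^2, 0.0, 1.0)         # ≈ 1/3 (Simpson is exact for quadratics)
integrate(sin, 0.0, π; atol = 1e-12)  # ≈ 2
```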
One possible solution is for Enzyme to force recompilation with full specialization before choosing between runtime and compile-time handling. This seems feasible for a package like Enzyme, and it would be fair game: I'm sure no one would object to this little bit of extra compilation in exchange for a faster, non-erroring gradient.
Reproducer below. Adding a type parameter (`f::F ... where {F}`) to force specialization works around the issue.
```julia
using Enzyme

constcall(a, info) = call(() -> a; info)

function call(f; info=nothing)  # errors
# function call(f::F; info=nothing) where {F}  # works
    @info "$info"  # must use `info` somehow for the error to appear
    return f()
end

function EnzymeRules.augmented_primal(
    config, ::Const{typeof(call)}, ::Type{<:Active}, f::Active; kws...,
)
    primal = EnzymeRules.needs_primal(config) ? call(f.val; kws...) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(  # this rule is totally wrong, but that's beside the point
    config, ::Const{typeof(call)}, ::Active, tape, f::Active; kws...,
)
    return (f.val,)
end

@show constcall(1.0, 1e-10)
@show autodiff(Reverse, constcall, Active, Active(1.0), Const(1e-10))
```
Output:
```
[ Info: 1.0e-10
constcall(1.0, 1.0e-10) = 1.0
ERROR: LoadError: Enzyme execution failed.
Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), Duplicated{@NamedTuple{info::Float64}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false), false}, Const{typeof(call)}, Type{Active{Float64}}, Active{var"#61#62"{Float64}}}
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7061 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6664 [inlined]
  [3] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6552 [inlined]
  [4] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Core.kwcall), df::Nothing, primal_1::@NamedTuple{…}, shadow_1_1::Base.RefValue{…}, primal_2::typeof(call), shadow_2_1::Nothing, primal_3::var"#61#62"{…}, shadow_3_1::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/jitrules.jl:368
  [5] constcall
    @ ~/issues/quadgkkwargs.jl:41 [inlined]
  [6] diffejulia_constcall_11499wrap
    @ ~/issues/quadgkkwargs.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7061 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6664 [inlined]
  [9] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6541 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:316 [inlined]
 [11] autodiff(::ReverseMode{…}, ::typeof(constcall), ::Type{…}, ::Active{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:328
 [12] macro expansion
    @ show.jl:1181 [inlined]
 [13] top-level scope
    @ ~/issues/quadgkkwargs.jl:63
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [15] top-level scope
    @ REPL[3]:1
in expression starting at /home/daniel/issues/quadgkkwargs.jl:63
Some type information was truncated. Use `show(err)` to see complete types.
```
(PS: This reproducer is somewhat deceptive in that `call` does call `f` in its body, so why is it still not specialized? My understanding is that the inner method holding the body is specialized, but not the keyword-handling wrapper (the `Core.kwcall` method, visible in the stacktrace) that is actually invoked by `call(f; info)`.)
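A small sketch of that lowering, assuming Julia 1.9 or later, where keyword calls go through `Core.kwcall` (matching the `typeof(Core.kwcall)` entries in the error above). `kwpass` is a hypothetical function with the same shape as `call`; the point is only that the surface keyword call and the explicit `Core.kwcall` form are equivalent.

```julia
# Hypothetical function with the same shape as `call` in the reproducer.
function kwpass(f; info = nothing)
    return (f(), info)
end

# Surface syntax: a keyword call...
r1 = kwpass(() -> 1.0; info = 1e-10)

# ...is lowered to `Core.kwcall(namedtuple, function, positional args...)`.
# That keyword-sorting method unpacks `info` and forwards `f` to the generated
# body method without calling it, so it may not be specialized on `typeof(f)`.
r2 = Core.kwcall((; info = 1e-10), kwpass, () -> 1.0)

r1 == r2  # true; both are (1.0, 1.0e-10)
```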