
Non-specialized methods with custom rules result in unnecessary use of runtime handlers and "Non-constant keyword argument" error

Open

Description

Successor to #1845

Julia avoids specializing methods on arguments in certain cases, most notably when the argument type is `<: Function` and the function is not called in the method body but only passed through to an inner function. This affects only code generation, not type inference, and runtime dispatch is usually avoided anyway because such "pass-through" methods are small and tend to be inlined.
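A minimal sketch of the two signatures in plain Julia (no Enzyme required). `apply`, `passthrough`, and `passthrough_spec` are hypothetical names used only for illustration. Both versions return the same results; the difference lies only in the generated code. Whether the first version is actually left unspecialized is up to the compiler's heuristics, while the type parameter `F` in the second forces specialization on `typeof(f)`.

```julia
# Hypothetical inner function: this is where `f` is actually called.
apply(f, x) = f(x)

# Pass-through: `f` is only forwarded, never called here, so Julia's
# heuristics may compile this method without specializing on typeof(f).
passthrough(f, x) = apply(f, x)

# A type parameter F forces specialization on typeof(f).
passthrough_spec(f::F, x) where {F} = apply(f, x)

@show passthrough(sin, 0.0)
@show passthrough_spec(x -> 2x, 3)
```

This is the same `f::F ... where {F}` idiom as the commented-out signature in the reproducer.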

However, if a custom rule is written for such a method, Enzyme treats the call as type unstable and falls back to the runtime handler, whose activity analysis is limited. Besides the performance penalty, this throws an error whenever the activity analysis fails to prove that a keyword argument is `Const`. The custom rules for QuadGK.jl are an important example, since `quadgk` takes both a function argument and inactive floating-point keyword arguments that set tolerances.

One solution could be for Enzyme to force recompilation with full specialization before choosing between runtime and compile-time handling. This seems feasible for a package like Enzyme and would be fair game: I'm certain no one would object to a little extra compilation in exchange for a faster gradient that doesn't error.

Reproducer below. Annotating the argument as `f::F` with a type parameter `F` (the commented-out signature) forces specialization and works around the issue.

```julia
using Enzyme

constcall(a, info) = call(() -> a; info)

function call(f; info=nothing)               # errors
# function call(f::F; info=nothing) where {F}  # works
    @info "$info"  # must use `info` somehow for the error to appear
    return f()
end

function EnzymeRules.augmented_primal(
    config, ::Const{typeof(call)}, ::Type{<:Active}, f::Active; kws...,
)
    primal = EnzymeRules.needs_primal(config) ? call(f.val; kws...) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(  # this rule is totally wrong, but that's beside the point
    config, ::Const{typeof(call)}, ::Active, tape, f::Active; kws...,
)
    return (f.val,)
end

@show constcall(1.0, 1e-10)
@show autodiff(Reverse, constcall, Active, Active(1.0), Const(1e-10))
```

Output:

```
[ Info: 1.0e-10
constcall(1.0, 1.0e-10) = 1.0
ERROR: LoadError: Enzyme execution failed.
Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), Duplicated{@NamedTuple{info::Float64}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false), false}, Const{typeof(call)}, Type{Active{Float64}}, Active{var"#61#62"{Float64}}}

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7061 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6664 [inlined]
  [3] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6552 [inlined]
  [4] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Core.kwcall), df::Nothing, primal_1::@NamedTuple{…}, shadow_1_1::Base.RefValue{…}, primal_2::typeof(call), shadow_2_1::Nothing, primal_3::var"#61#62"{…}, shadow_3_1::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/jitrules.jl:368
  [5] constcall
    @ ~/issues/quadgkkwargs.jl:41 [inlined]
  [6] diffejulia_constcall_11499wrap
    @ ~/issues/quadgkkwargs.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7061 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6664 [inlined]
  [9] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6541 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:316 [inlined]
 [11] autodiff(::ReverseMode{…}, ::typeof(constcall), ::Type{…}, ::Active{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:328
 [12] macro expansion
    @ show.jl:1181 [inlined]
 [13] top-level scope
    @ ~/issues/quadgkkwargs.jl:63
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [15] top-level scope
    @ REPL[3]:1
in expression starting at /home/daniel/issues/quadgkkwargs.jl:63
Some type information was truncated. Use `show(err)` to see complete types.
```

(PS: This reproducer is somewhat deceptive: `call` does call `f` in its body, so why is it still not specialized? My understanding is that the inner body method is specialized, but the keyword-handling wrapper that `call(f; info)` actually invokes is not.)
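To make the PS more concrete: on Julia 1.9 and later, a call with keyword arguments goes through `Core.kwcall`, with the keywords packed into a `NamedTuple` as the first argument. This matches the `typeof(Core.kwcall)` and `@NamedTuple{info::Float64}` entries in the error above. The sketch below uses a hypothetical `call_demo` in place of `call` so that it runs without Enzyme; it only shows the equivalence of the two call forms, not how Enzyme handles them.

```julia
# Hypothetical stand-in for `call` from the reproducer (requires Julia >= 1.9).
function call_demo(f; info=nothing)
    return (f(), info)
end

g = () -> 1.0

# Ordinary keyword-call syntax...
a = call_demo(g; info=1e-10)

# ...is equivalent to invoking the keyword wrapper directly, with the
# keyword arguments packed into a NamedTuple as the first argument.
b = Core.kwcall((; info=1e-10), call_demo, g)

@show a == b
```

In this form `f` is just another argument that `Core.kwcall` forwards to the body method, which is the pass-through pattern described at the top of the issue.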
