
Imported rule type mismatch? #2520

@Red-Portal

Description


Hi! Importing rules into Enzyme currently fails in AdvancedVI for some reason. While trying to dig into this, I hit a different error. Here is an MWE:

```julia
using ChainRulesCore
using Enzyme
using Statistics
using DifferentiationInterface
using ADTypes

# Fields are declared as bare `Array` (abstract: no element type or rank given)
struct Prob
    x::Array
    y::Array
end

logdensity(prob::Prob, x::AbstractArray) = sum(x) + sum(prob.x)

# Stub returning a fixed value and gradient, independent of `x`
logdensity_and_gradient(::Prob, x::AbstractArray) = (1.0, [-1.0, 1.0])

function ChainRulesCore.rrule(
    ::typeof(logdensity),
    mixedad_prob::Prob,
    x::AbstractArray,
)
    ℓπ, ∇ℓπ = logdensity_and_gradient(mixedad_prob, x)
    # `logdensity` returns a scalar, so the incoming cotangent ∂y is a scalar,
    # not an array; an `::AbstractArray` annotation here would never match it
    function logdensity_pullback(∂y)
        ∂x = ChainRulesCore.@thunk(∂y' * ∇ℓπ)
        return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ∂x
    end
    return ℓπ, logdensity_pullback
end

function mixedad_test_fwd(x, prob)
    xs = repeat(x, 1, 2)  # `xs` is not used below
    return logdensity(prob, x)
end

# Turn the ChainRules rrule into an Enzyme rule for this signature
Enzyme.@import_rrule(typeof(logdensity), Prob, AbstractArray)

x = randn(2)
dx = zero(x)
adtype = AutoEnzyme(;
    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
    function_annotation=Enzyme.Const,
)
prob = Prob(randn(1), randn(1))
DifferentiationInterface.value_and_gradient(mixedad_test_fwd, adtype, x, Constant(prob))
```

On Julia 1.10, this results in

```
julia> DifferentiationInterface.value_and_gradient(mixedad_test_fwd, adtype, x, Constant(prob))
ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.RevConfigWidth{1, true, true, (false, false, false), true, false}, Const{typeof(logdensity)}, Type{Duplicated{Any}}, Const{Prob}, Duplicated{Vector{Float64}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{Any, Any, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{Float64, Float64, Tuple{Float64, var"#logdensity_pullback#4"{Vector{Float64}}}}

Stacktrace:
  [1] mixedad_test_fwd
    @ ./REPL[10]:3 [inlined]
  [2] mixedad_test_fwd
    @ ./REPL[10]:0 [inlined]
  [3] augmented_julia_mixedad_test_fwd_346_inner_1wrap
    @ ./REPL[10]:0
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/g07uU/src/compiler.jl:5481 [inlined]
  [5] enzyme_call
    @ ~/.julia/packages/Enzyme/g07uU/src/compiler.jl:5015 [inlined]
  [6] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/g07uU/src/compiler.jl:4954 [inlined]
  [7] autodiff
    @ ~/.julia/packages/Enzyme/g07uU/src/Enzyme.jl:408 [inlined]
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/g07uU/src/sugar.jl:275 [inlined]
  [9] gradient
    @ ~/.julia/packages/Enzyme/g07uU/src/sugar.jl:262 [inlined]
 [10] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/alBlj/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:253 [inlined]
 [11] value_and_gradient(f::typeof(mixedad_test_fwd), backend::AutoEnzyme{…}, x::Vector{…}, contexts::Constant{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:37
 [12] top-level scope
    @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.
```

The error message doesn't quite make sense to me: it says that `EnzymeCore.EnzymeRules.AugmentedReturn{Any, Any, Any}` is expected, but then complains about receiving `EnzymeCore.EnzymeRules.AugmentedReturn{Float64, Float64, Tuple{Float64, var"#logdensity_pullback#4"{Vector{Float64}}}}`, whose parameters all look like narrower versions of the expected ones (?)
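One possible reading of the mismatch, offered as a hypothesis rather than a confirmed diagnosis: the rule signature in the error contains `Type{Duplicated{Any}}`, which suggests Enzyme inferred the primal return type of `logdensity` as `Any`. `Prob` declares its fields as bare `Array`, an abstract type, so `sum(prob.x)` cannot be inferred concretely. Separately, Julia's parametric types are invariant, so a value whose type is `AugmentedReturn{Float64, Float64, Tuple{...}}` is not an instance of `AugmentedReturn{Any, Any, Any}`, even though `Float64 <: Any`. The stdlib-only sketch below checks both language-level points; `ProbConcrete` and `FakeAugmentedReturn` are hypothetical names introduced for illustration, and `FakeAugmentedReturn` is a stand-in, not Enzyme's actual `AugmentedReturn`.

```julia
# Stdlib-only sketch; run in a fresh session (no Enzyme or ChainRulesCore needed).

# Same definitions as in the MWE: bare `Array` fields are abstractly typed.
struct Prob
    x::Array
    y::Array
end
logdensity(prob::Prob, x::AbstractArray) = sum(x) + sum(prob.x)

# Hypothetical concretely typed variant, used only for comparison.
struct ProbConcrete
    x::Vector{Float64}
    y::Vector{Float64}
end
logdensity(prob::ProbConcrete, x::AbstractArray) = sum(x) + sum(prob.x)

# Ask type inference what each method returns for a Vector{Float64} argument.
rt_abstract = only(Base.return_types(logdensity, (Prob, Vector{Float64})))
rt_concrete = only(Base.return_types(logdensity, (ProbConcrete, Vector{Float64})))
println("inferred with abstract fields: ", rt_abstract)
println("inferred with concrete fields: ", rt_concrete)   # Float64

# Hypothetical stand-in for a three-parameter return wrapper; NOT Enzyme's type.
struct FakeAugmentedReturn{P,S,T}
    val::P
    dval::S
    tape::T
end

found = FakeAugmentedReturn(1.0, 0.0, (1.0, nothing))
# Invariance: Float64 <: Any, yet FakeAugmentedReturn{Float64, ...}
# is not a subtype of FakeAugmentedReturn{Any, Any, Any}.
println(typeof(found) <: FakeAugmentedReturn{Any,Any,Any})   # false
println(typeof(found) <: FakeAugmentedReturn)                # true (matches the UnionAll)
```

Comparing the two `Base.return_types` results is a cheap way to spot this kind of type instability before Enzyme is involved at all.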

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
