Closed
Description
Hi! Currently, importing rules into Enzyme is failing on AdvancedVI for some reason. While digging into this, I hit a different error. Here is an MWE:
using ChainRulesCore
using Enzyme
using Statistics
using DifferentiationInterface
using ADTypes
struct Prob
    x::Array
    y::Array
end
logdensity(prob::Prob, x::AbstractArray) = sum(x) + sum(prob.x)
logdensity_and_gradient(::Prob, x::AbstractArray) = (1., [-1.0, 1.0])
function ChainRulesCore.rrule(
    ::typeof(logdensity),
    mixedad_prob::Prob,
    x::AbstractArray,
)
    ℓπ, ∇ℓπ = logdensity_and_gradient(mixedad_prob, x)
    function logdensity_pullback(∂y::AbstractArray)
        ∂x = ChainRulesCore.@thunk(∂y' * ∇ℓπ)
        return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ∂x
    end
    return ℓπ, logdensity_pullback
end
function mixedad_test_fwd(x, prob)
    xs = repeat(x, 1, 2)
    return logdensity(prob, x)
end
Enzyme.@import_rrule(typeof(logdensity), Prob, AbstractArray)
x = randn(2)
dx = zero(x)
adtype = AutoEnzyme(;
    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
    function_annotation=Enzyme.Const,
)
prob = Prob(randn(1), randn(1))
DifferentiationInterface.value_and_gradient(mixedad_test_fwd, adtype, x, Constant(prob))

On Julia 1.10, this results in:
julia> DifferentiationInterface.value_and_gradient(mixedad_test_fwd, adtype, x, Constant(prob))
ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.RevConfigWidth{1, true, true, (false, false, false), true, false}, Const{typeof(logdensity)}, Type{Duplicated{Any}}, Const{Prob}, Duplicated{Vector{Float64}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{Any, Any, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{Float64, Float64, Tuple{Float64, var"#logdensity_pullback#4"{Vector{Float64}}}}
Stacktrace:
  [1] mixedad_test_fwd
    @ ./REPL[10]:3 [inlined]
  [2] mixedad_test_fwd
    @ ./REPL[10]:0 [inlined]
  [3] augmented_julia_mixedad_test_fwd_346_inner_1wrap
    @ ./REPL[10]:0
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/g07uU/src/compiler.jl:5481 [inlined]
  [5] enzyme_call
    @ ~/.julia/packages/Enzyme/g07uU/src/compiler.jl:5015 [inlined]
  [6] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/g07uU/src/compiler.jl:4954 [inlined]
  [7] autodiff
    @ ~/.julia/packages/Enzyme/g07uU/src/Enzyme.jl:408 [inlined]
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/g07uU/src/sugar.jl:275 [inlined]
  [9] gradient
    @ ~/.julia/packages/Enzyme/g07uU/src/sugar.jl:262 [inlined]
 [10] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/alBlj/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:253 [inlined]
 [11] value_and_gradient(f::typeof(mixedad_test_fwd), backend::AutoEnzyme{…}, x::Vector{…}, contexts::Constant{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:37
 [12] top-level scope
    @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.

The error message itself doesn't quite make sense: it says that EnzymeCore.EnzymeRules.AugmentedReturn{Any, Any, Any} is expected, but complains about receiving EnzymeCore.EnzymeRules.AugmentedReturn{Float64, Float64, Tuple{Float64, var"#logdensity_pullback#4"{Vector{Float64}}}} (?)