Use zygote2differential to wrap chainrules inputs#1057
Use zygote2differential to wrap chainrules inputs#1057
Conversation
|
well this breaks some tests in weird ways. |
|
I just checked TuringLang/DistributionsAD.jl#198 (the CR1 version) locally and it still fails with the same error messages ("adjoint for constructor ..."), even with this PR. |
|
Yeah won't fix that. |
|
There was a matching differential2zygote that@mzgubic wrote. |
|
Ah sorry, I misunderstood your comment. Unfortunately, the example is not fixed either. |
|
Here it is, in case you find it useful: differential2legacy(x) = unthunk(x) # TODO eventually remove this
differential2legacy(::AbstractZero) = nothing
differential2legacy(t::Union{Tuple, NamedTuple}) = map(differential2legacy, t)
differential2legacy(::Nothing) = (legacytype_warn(Nothing); return nothing)
differential2legacy(a::AbstractArray) = differential2legacy.(a) # TODO: what to do with arrays with nothing?
differential2legacy(a::AbstractArray{<:Number}) = a
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline function differential2legacy(x::Composite{P, T}) where {P, T<:$T_outer}
xp = map(differential2legacy, canonicalize(x))
convert($T_outer, xp)
end
endI do recall getting into some kind of trouble when using this instead of |
@mzgubic implemented zygote2differential as a better version of wrap_chainrules_inputs and added it to use in the code for
rrule_via_ad.But it was not added to the normal path for when Zygote uses ChainRules.
I guess because it requires keeping the primal values in memory.
Which is probably a lot?
Anyway this would give us more consistent chainrules types.
No more
Tangent{Any}ornothingsthat are hidden with-in arrays.We probably do not want to merge this as is because of the extra memory use.
or maybe it is not too bad. Do we have a benchmark for it?
But hopefully this will fix the problems in TuringLang/DistributionsAD.jl#197
cc @devmotion .
If it does we can look at reworking
zygote2differentialto not have to store so much.We learnt a lot about doing that for
ProjectTosame techniques can be applied here.
NB: I am putting this PR up at 9:30 at night, and I have not even run it locally.
Might have typos etc and just not work.
It also has no tests, yet.