Open
Description
I've run into a DimensionMismatch error when broadcasting arrays with generic element types. The general flavor of calculation I want to do is like this:
using Zygote
function okay(a, b)
    sum(a .+ b)
end
a = rand(2,5,3)
b = rand(2,5,1,4)
@show Zygote.withgradient(okay, a, b)
This gives the expected output:
Zygote.withgradient(okay, a, b) = (val = 105.52328765180295, grad = ([4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0], [3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0]))
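For reference, the 4.0 and 3.0 entries above can be reproduced by hand: each input's gradient is the cotangent of the broadcast result summed over the dimensions along which that input was expanded. The following is only a sketch in plain Julia (no Zygote); the names dy, da and db are illustrative and do not come from Zygote.

```julia
a = rand(2, 5, 3)
b = rand(2, 5, 1, 4)

# The cotangent of sum(a .+ b) with respect to the broadcast result is all ones.
dy = ones(size(a .+ b))                      # size (2, 5, 3, 4)

# a was expanded along dimension 4, so sum over it and drop that dimension.
da = dropdims(sum(dy; dims=4); dims=4)       # size (2, 5, 3)

# b has a singleton dimension 3 that was expanded, so sum over it and keep it.
db = sum(dy; dims=3)                         # size (2, 5, 1, 4)

@show size(da) size(db)
```

Every entry of da is 4.0 and every entry of db is 3.0, matching the shapes and values Zygote reports.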
However, when I wrap each array element in a named tuple as follows:
p(a, b) = a.x + b.x
function mwe(a, b)
    sum(p.(a, b))
end
ag = NamedTuple{(:x,)}.(tuple.(a))
bg = NamedTuple{(:x,)}.(tuple.(b))
@show Zygote.withgradient(mwe, ag, bg)
I get:
ERROR: LoadError: DimensionMismatch: variable with size(x) == (2, 5, 3) cannot have a gradient with size(dx) == (30,)
Stacktrace:
[1] (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{elements::Array{ChainRulesCore.ProjectTo{…}, 3}, axes::Tuple{Base.OneTo{…}, Base.OneTo{…}, Base.OneTo{…}}}})(dx::Vector{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/XAgYn/src/projection.jl:229
[2] _project
@ /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/chainrules.jl:200 [inlined]
[3] unbroadcast(x::Array{@NamedTuple{x::Float64}, 3}, maybethunked_x̄::Array{@NamedTuple{x::Float64}, 4})
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:63
[4] map
@ ./tuple.jl:383 [inlined]
[5] ∇broadcasted
@ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:222 [inlined]
[6] (::Zygote.var"#4145#back#1372"{Zygote.var"#∇broadcasted#1383"{Tuple{…}, Array{…}, Val{…}}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#719#722"{…}}, ChainRules.var"#718#721"{Float64, Colon}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[7] #305
@ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/lib.jl:214 [inlined]
[8] #2189#back
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
[9] broadcasted
@ ./broadcast.jl:1331 [inlined]
[10] mwe
@ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:13 [inlined]
[11] (::Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}})(Δ::Float64)
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface2.jl:0
[12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}})(Δ::Float64)
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:97
[13] withgradient(::Function, ::Array{@NamedTuple{x::Float64}, 3}, ::Vararg{Any})
@ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:219
[14] macro expansion
@ show.jl:1232 [inlined]
[15] top-level scope
@ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
[16] include(fname::String)
@ Main ./sysimg.jl:38
[17] top-level scope
@ REPL[3]:1
in expression starting at /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
Some type information was truncated. Use `show(err)` to see complete types.
I've worked out that an unintended flattening of arrays happens in the following line (Zygote.jl/src/compiler/chainrules.jl, line 297 at e0af1a8):
z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal)
Here, map flattens the arrays, whereas I believe they are intended to be broadcast, at least in the context of unbroadcast. Perhaps unbroadcast should drop trailing singleton dimensions? I am not sure what the correct fix is.
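The flattening can be seen in isolation, without Zygote. The sketch below uses Float64 arrays as stand-ins for the primal and the un-reduced cotangent; the names x, dx and flat are illustrative, and + stands in for z2d. It assumes Base's behaviour of returning a Vector when map is given arrays with different numbers of dimensions, which is consistent with the size(dx) == (30,) reported in the error.

```julia
x  = fill(1.0, 2, 5, 3)       # stand-in for the primal, 30 elements
dx = fill(1.0, 2, 5, 3, 4)    # stand-in for the cotangent before reduction, 120 elements

# With arguments of different dimensionality, map iterates them in lockstep
# up to the shorter length and collects the result into a Vector.
flat = map(+, dx, x)
@show size(flat)

# Reducing over the extra dimension first gives an array shaped like x.
reduced = dropdims(sum(dx; dims=4); dims=4)
@show size(reduced) == size(x)
```

This is only meant to illustrate the shape problem; it does not suggest which of z2d or unbroadcast should change.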