
size error when unbroadcasting arrays with generic eltypes #1585

@lxvm

I've run into a `DimensionMismatch` error when taking gradients through broadcasts over arrays with generic element types. The calculation I want to do looks, in general, like this:

```julia
using Zygote

function okay(a, b)
    sum(a .+ b)
end

a = rand(2,5,3)
b = rand(2,5,1,4)
@show Zygote.withgradient(okay, a, b)
```

which gives the expected output:

```
Zygote.withgradient(okay, a, b) = (val = 105.52328765180295, grad = ([4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0;;; 4.0 4.0 4.0 4.0 4.0; 4.0 4.0 4.0 4.0 4.0], [3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0;;;; 3.0 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0 3.0]))
```
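For reference, the gradient values above (all 4.0 for `a`, all 3.0 for `b`) follow from summing the all-ones cotangent of `sum` over the dimensions along which each input was expanded. The sketch below reproduces them in plain Julia with `sum(...; dims)`; it is a hand-rolled illustration of what unbroadcasting does, not Zygote's actual `unbroadcast` implementation. Note that the reduction for `a` leaves a trailing singleton dimension that has to be removed to match `size(a)`.

```julia
a = rand(2, 5, 3)
b = rand(2, 5, 1, 4)

c = a .+ b                 # broadcasting expands both inputs to size (2, 5, 3, 4)
dc = ones(size(c))         # cotangent of sum(c) with respect to c is all ones

# "Unbroadcast" by summing over the dimensions each input was expanded along.
da4 = sum(dc; dims = 4)    # size (2, 5, 3, 1): a was expanded along dim 4 (length 4)
db = sum(dc; dims = 3)     # size (2, 5, 1, 4): b was expanded along dim 3 (length 3)

# da4 carries a trailing singleton dimension relative to a; reshape drops it.
da = reshape(da4, size(a))

@show size(c) size(da4) size(da) size(db)
@show all(==(4.0), da) all(==(3.0), db)
```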

However, when I wrap each array element in a `NamedTuple` as follows,

```julia
p(a, b) = a.x + b.x
function mwe(a, b)
    sum(p.(a, b))
end

ag = NamedTuple{(:x,)}.(tuple.(a))
bg = NamedTuple{(:x,)}.(tuple.(b))
@show Zygote.withgradient(mwe, ag, bg)
```
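As a sanity check in plain Julia (no Zygote involved), the wrapping changes only the element type: the arrays keep their shapes, and the forward pass `p.(ag, bg)` produces exactly the same values as `a .+ b`. This suggests the failure is confined to the backward pass.

```julia
a = rand(2, 5, 3)
b = rand(2, 5, 1, 4)

# Wrap each element in a one-field NamedTuple, as in the example above.
ag = NamedTuple{(:x,)}.(tuple.(a))
bg = NamedTuple{(:x,)}.(tuple.(b))

p(a, b) = a.x + b.x

@show eltype(ag) size(ag) size(bg)

# The forward pass matches the purely numeric broadcast element for element.
@show p.(ag, bg) == a .+ b
```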

I get the following error:

```
ERROR: LoadError: DimensionMismatch: variable with size(x) == (2, 5, 3) cannot have a gradient with size(dx) == (30,)
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{elements::Array{ChainRulesCore.ProjectTo{…}, 3}, axes::Tuple{Base.OneTo{…}, Base.OneTo{…}, Base.OneTo{…}}}})(dx::Vector{ChainRulesCore.Tangent{@NamedTuple{x::Float64}, @NamedTuple{x::Float64}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/XAgYn/src/projection.jl:229
  [2] _project
    @ /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/chainrules.jl:200 [inlined]
  [3] unbroadcast(x::Array{@NamedTuple{x::Float64}, 3}, maybethunked_x̄::Array{@NamedTuple{x::Float64}, 4})
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:63
  [4] map
    @ ./tuple.jl:383 [inlined]
  [5] ∇broadcasted
    @ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/broadcast.jl:222 [inlined]
  [6] (::Zygote.var"#4145#back#1372"{Zygote.var"#∇broadcasted#1383"{Tuple{…}, Array{…}, Val{…}}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#719#722"{…}}, ChainRules.var"#718#721"{Float64, Colon}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [7] #305
    @ /local/home/lxvm/projects/contrib/dev/Zygote/src/lib/lib.jl:214 [inlined]
  [8] #2189#back
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72 [inlined]
  [9] broadcasted
    @ ./broadcast.jl:1331 [inlined]
 [10] mwe
    @ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:13 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}})(Δ::Float64)
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{typeof(mwe), Array{@NamedTuple{…}, 3}, Array{@NamedTuple{…}, 4}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{}}, Zygote.ZBack{ChainRules.var"#sum_pullback#720"{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}})(Δ::Float64)
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:97
 [13] withgradient(::Function, ::Array{@NamedTuple{x::Float64}, 3}, ::Vararg{Any})
    @ Zygote /local/home/lxvm/projects/contrib/dev/Zygote/src/compiler/interface.jl:219
 [14] macro expansion
    @ show.jl:1232 [inlined]
 [15] top-level scope
    @ /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
 [16] include(fname::String)
    @ Main ./sysimg.jl:38
 [17] top-level scope
    @ REPL[3]:1
in expression starting at /local/home/lxvm/projects/contrib/zygote_mwe2.jl:18
Some type information was truncated. Use `show(err)` to see complete types.
```

I've traced the problem to an unintended flattening of arrays in the following line:

```julia
z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal)
```

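Here, `map` flattens its arguments whenever `dx` and `primal` have different numbers of dimensions, although I believe they are meant to be matched up as in broadcasting, at least in the context of `unbroadcast`. As far as I can tell, in this example `unbroadcast` sums the `(2, 5, 3, 4)` gradient over its 4th dimension, which leaves a `(2, 5, 3, 1)` array; mapping over that together with the `(2, 5, 3)` primal yields a 30-element `Vector`, which `ProjectTo` then rejects. Perhaps `unbroadcast` should drop trailing singleton dimensions? I am not sure what the correct fix is.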
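The flattening can be reproduced with `map` alone, without Zygote: when its array arguments have different numbers of dimensions, `Base.map` iterates them jointly until the shortest is exhausted and returns a `Vector`. The helper `drop_trailing_singletons` below is hypothetical, a minimal sketch of the "drop trailing singleton dimensions" idea that assumes `dx` only ever carries extra size-1 dimensions beyond `ndims(primal)`; it is not part of Zygote.

```julia
primal = rand(2, 5, 3)
dx = rand(2, 5, 3, 1)    # the shape left after summing a (2, 5, 3, 4) cotangent over dims = 4

# map over arrays with different ndims returns a flat Vector (here of length 30)
flat = map((d, p) -> d, dx, primal)
@show size(flat)

# Hypothetical helper (assumption, not Zygote API): drop dx's trailing size-1
# dimensions beyond ndims(primal), and refuse anything that is not size 1.
function drop_trailing_singletons(dx::AbstractArray, primal::AbstractArray)
    N = ndims(primal)
    ndims(dx) <= N && return dx
    extra = ntuple(i -> N + i, ndims(dx) - N)
    all(d -> size(dx, d) == 1, extra) ||
        throw(DimensionMismatch("dx has non-singleton dimensions beyond ndims(primal)"))
    return dropdims(dx; dims = extra)
end

# With matching ndims, map keeps the array shape instead of flattening.
shaped = map((d, p) -> d, drop_trailing_singletons(dx, primal), primal)
@show size(shaped)
```

Dropping only singleton dimensions keeps the check strict: a `dx` with genuinely mismatched trailing sizes still errors rather than being silently reshaped.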
