
Gradient of cat that introduces new dims does not match the dims of the input #1061

Open
@chengchingwen

Description

MWE:

julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3×1×1 Array{Float64, 4}: # expected 3×3, but got 3×3×1×1
[:, :, 1, 1] =
 0.930559   0.810081  0.894403
 0.951607  -0.659616  0.310079
 0.950346   0.774937  0.910482
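One way to get the expected shape is for the pullback of cat to slice each input's block out of the output cotangent and reshape it to that input's size. This drops any trailing singleton dims that cat introduced. The function below is a hypothetical, Base-only sketch of that idea: the name `cat_pullback_sketch` is made up, only an integer `dims` is handled, and it is not Zygote's actual adjoint or a ChainRules `rrule`.

```julia
# Hypothetical sketch (not Zygote's actual adjoint): split the cotangent `dy`
# of `y = cat(xs...; dims=dims)` into one gradient per input, reshaping each
# block to that input's own size so dims introduced by `cat` are dropped.
# Assumption: `dims` is a single Int (Val or tuple dims are not handled).
function cat_pullback_sketch(dy::AbstractArray, xs::AbstractArray...; dims::Int)
    grads = Vector{Any}(undef, length(xs))
    offset = 0
    for (k, x) in enumerate(xs)
        # size(x, d) is 1 when d > ndims(x), i.e. when `dims` is a new dim.
        n = size(x, dims)
        # Take this input's slab along `dims` and everything along other dims.
        idx = ntuple(i -> i == dims ? ((offset + 1):(offset + n)) : Colon(), ndims(dy))
        grads[k] = reshape(dy[idx...], size(x))
        offset += n
    end
    return grads
end
```

For the MWE above, `cat_pullback_sketch(cos.(cat(x; dims=4)), x; dims=4)[1]` returns a 3×3 array equal to `cos.(x)`, which is the shape the reported gradient should have.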

Metadata

Labels

ChainRules (adjoint -> rrule, and further integration)
