Gradient of cat that introduces new dims does not match the dims of the input #1061

Open
chengchingwen opened this issue Sep 4, 2021 · 5 comments
Labels
ChainRules adjoint -> rrule, and further integration

Comments

@chengchingwen
Member

MWE:

julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3×1×1 Array{Float64, 4}: # expected 3×3, but got 3×3×1×1
[:, :, 1, 1] =
 0.930559   0.810081  0.894403
 0.951607  -0.659616  0.310079
 0.950346   0.774937  0.910482
@mcabbott
Member

mcabbott commented Sep 4, 2021

Fixed in ChainRules, but I think Zygote still uses its own older versions:

julia> Diffractor.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3 Matrix{Float64}:
 0.954619  0.903961  0.297481
 0.174126  0.99136   0.972581
 0.53144   0.439785  0.877705

@chengchingwen
Member Author

Seems to be fixed already. Tested with Zygote 0.6.34:

julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3 Matrix{Float64}:
 0.598398   0.988603  0.999602
 0.84835   -0.101217  0.391286
 0.9785    -0.717415  0.87662

@mcabbott
Member

mcabbott commented Feb 12, 2022

FWIW, this is because gradient calls ProjectTo on the final answer. The rule for cat itself is unchanged, so intermediate results may still carry the extra dimensions, and an earlier step in the chain can fail on them:

julia> Zygote.pullback(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[2](1.0)[1]
3×3×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.392695  -0.145669  0.450509
 0.334361   0.316647  0.980656
 0.987435   0.843904  0.985057
 
julia> gradient(x -> sum(abs2, cat(x * x'; dims=4)), [1 2; 3 4])
ERROR: MethodError: no method matching *(::Array{Int64, 4}, ::Matrix{Int64})
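
For reference, the projection step mentioned above can be sketched in isolation with ChainRulesCore. This is only an illustration, not Zygote's actual code path: the array `dx` is a made-up stand-in for the 4-d cotangent that the cat rule returns, and the variable names are hypothetical.

```julia
using ChainRulesCore  # assumed to be installed; provides ProjectTo

x  = randn(3, 3)          # primal input, a 3×3 Matrix
dx = ones(3, 3, 1, 1)     # stand-in for the 4-d cotangent coming out of the cat rule

project = ProjectTo(x)    # records the axes and element type of x
dx_proj = project(dx)     # reshapes away the trailing singleton dimensions

size(dx_proj)             # (3, 3)
```

Projection only reshapes when the extra dimensions are trivial, which is why the final answer from gradient looks right while the intermediate cotangent passed to the rule for `*` is still 4-dimensional.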

@chengchingwen chengchingwen reopened this Feb 12, 2022
@mcabbott
Member

IIRC the hurdle to simply deleting all of these (Zygote's own rules) is JuliaGPU/GPUArrays.jl#362. vcat of a mix of numbers and CuArrays mostly works, and its gradient should not use scalar indexing. ChainRules doesn't depend on GPU stuff, but it calls sum(view(x,i)) as a workaround... which doesn't work yet. Maybe there's a smarter way.

@mcabbott mcabbott added the ChainRules adjoint -> rrule, and further integration label Jul 4, 2022
@mzgubic
Collaborator

mzgubic commented Aug 1, 2022

This is solved now, right?

3 participants