Zygote can't differentiate matrix literal with double semicolon #1413

Open
white-alistair opened this issue Apr 14, 2023 · 2 comments
Labels
ChainRules adjoint -> rrule, and further integration

Comments

@white-alistair

MWE

using Zygote

function mul_good(x)
    A = [x 1]
    b = [2, 3]
    return (A * b)[1]
end

function mul_bad(x)
    A = [x;; 1]
    b = [2, 3]
    return (A * b)[1]
end

Zygote.gradient(mul_good, 1)  # (2.0,)
Zygote.gradient(mul_bad, 1)   # ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Int64}, ...)

Full stacktrace

ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Int64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Int64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:88
  [3] (::Zygote.var"#550#551"{Matrix{Int64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:100
  [4] (::Zygote.var"#2620#back#552"{Zygote.var"#550#551"{Matrix{Int64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [5] Pullback
    @ ./abstractarray.jl:2355 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(Base.hvncat_fill!), Matrix{Int64}, Bool, Tuple{Int64, Int64}}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./abstractarray.jl:2227 [inlined]
  [8] (::Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Val{2}, Int64, Int64}, Any})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Val{2}, Int64, Int64}, Any}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206
 [10] #2138#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [11] Pullback
    @ ./abstractarray.jl:2215 [inlined]
 [12] (::Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Int64, Bool, Int64, Int64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#1891#back#157"{Zygote.var"#153#156"}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Val{2}, Int64, Int64}, Any}}}}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [13] #287
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
 [14] #2138#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [15] Pullback
    @ ./abstractarray.jl:2193 [inlined]
 [16] (::Zygote.Pullback{Tuple{typeof(Base._hvncat), Int64, Bool, Int64, Int64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Int64, Bool, Int64, Int64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#1891#back#157"{Zygote.var"#153#156"}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Val{2}, Int64, Int64}, Any}}}}}}}}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [17] #287
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
 [18] #2138#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [19] Pullback
    @ ./abstractarray.jl:2189 [inlined]
 [20] Pullback
    @ ~/code/NeuralDiffEqTools.jl/zygote_bug.jl:10 [inlined]
 [21] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mul_bad), Int64}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1486"{Matrix{Int64}, Vector{Int64}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1368"{2, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(hvncat), Int64, Int64, Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._hvncat), Int64, Bool, Int64, Int64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Int64, Bool, Int64, Int64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#1891#back#157"{Zygote.var"#153#156"}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base._typed_hvncat), Type{Int64}, Val{2}, Int64, Int64}, Any}}}}}}}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [22] gradient(f::Function, args::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
 [23] top-level scope
    @ ~/code/NeuralDiffEqTools.jl/zygote_bug.jl:16

Version Info

Zygote v0.6.60

Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, tigerlake)
  Threads: 1 on 8 virtual cores
Environment:
  LD_PRELOAD = /usr/lib/x86_64-linux-gnu/libstdc++.so.6
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 
@white-alistair
Author

Seems related to #513

@white-alistair white-alistair changed the title Zygote can't differentiate through matrix literal with double semicolon Zygote can't differentiate matrix literal with double semicolon Apr 14, 2023
@ToucheSir
Member

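Somewhat. Perhaps confusingly, adding more semicolons lowers to a completely different function: [x 1] becomes a call to Base.hcat, whereas the double-semicolon form becomes a call to Base.hvncat: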

julia> Meta.@lower [a;; b]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Base.hvncat(2, a, b)
└──      return %1
))))
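
For comparison, the three literal forms can be lowered side by side. This is a minimal sketch that needs only Base; `lowered_text` is a hypothetical helper introduced here for illustration, and the exact printed form of the lowered code may differ between Julia versions.

```julia
# Hypothetical helper: return the lowered form of an expression as a string.
lowered_text(ex) = string(Meta.lower(Main, ex))

# Space-separated, single-semicolon, and double-semicolon literals.
for ex in (:([a b]), :([a; b]), :([a;; b]))
    println(ex)
    println(lowered_text(ex))
end
```

Each printed block should name the Base function the syntax maps to: hcat for the space-separated form, vcat for the single semicolon, and hvncat for the double semicolon.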

The ideal solution would be a rule for hvncat in ChainRules. Failing that, we can consider adding one to Zygote.
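
Until such a rule exists, one possible stopgap is to build the matrix with an explicit hcat call, which is the function that `[x 1]` in mul_good lowers to. The sketch below assumes Zygote is installed; `mul_workaround` is a hypothetical name, not part of the original report.

```julia
using Zygote

# Hypothetical variant of mul_bad that avoids the hvncat code path.
function mul_workaround(x)
    A = hcat(x, 1)   # same 1×2 matrix as [x;; 1], built via hcat
    b = [2, 3]
    return (A * b)[1]
end

# Takes the same hcat path as mul_good, so it should return the same gradient.
Zygote.gradient(mul_workaround, 1)
```

This changes only how the matrix is constructed; the array it produces is the same as the double-semicolon literal.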

@ToucheSir ToucheSir added the ChainRules adjoint -> rrule, and further integration label Apr 17, 2023