Description
This particular issue is really centered around Zygote, but I'm guessing that IRTools is part of the cause.
In my probabilistic programming framework, I have learnable parameters:
```julia
function learnable_hypers()
    l = learnable(:l, Float64[4.0, 4.0])
    m = learnable(:m, 10.0)
    q = rand(:q, Normal(l[1], 1.0 + exp(m)))
    return q
end
```

Gradients can be computed for these parameters. However, I'm seeing drastically different behavior depending on the type annotation of Array-valued parameters.
In particular, the above works correctly with the pullbacks I've defined here. However, if I change the element type annotation from `Float64` to `Any`:
```julia
function learnable_hypers()
    l = learnable(:l, Any[4.0, 4.0])
    m = learnable(:m, 10.0)
    q = rand(:q, Normal(l[1], 1.0 + exp(m)))
    return q
end
```

I get an array mutation error:
```
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#1048#1049")(::Nothing) at /home/mccoy/.julia/packages/Zygote/YeCEW/src/lib/array.jl:61
[3] (::Zygote.var"#2775#back#1050"{Zygote.var"#1048#1049"})(::Nothing) at /home/mccoy/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] hvcat_fill at ./abstractarray.jl:1707 [inlined]
[5] (::typeof(∂(λ)))(::Nothing) at /home/mccoy/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[6] typed_hvcat at ./abstractarray.jl:1729 [inlined]
[7] (::typeof(∂(λ)))(::Nothing) at /home/mccoy/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[8] learnable_hypers at /home/mccoy/.julia/dev/Jaynes/scratch/learnable.jl:8 [inlined]
[9] (::typeof(∂(λ)))(::Float64) at /home/mccoy/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[10] #174 at /home/mccoy/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182 [inlined]
[11] #347#back at /home/mccoy/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[12] #84 at /home/mccoy/.julia/dev/Jaynes/src/contexts/backpropagate.jl:152 [inlined]
[13] (::typeof(∂(λ)))(::Tuple{Float64,Float64}) at /home/mccoy/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[14] (::Zygote.var"#36#37"{typeof(∂(λ))})(::Tuple{Float64,Float64}) at /home/mccoy/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:46
[15] accumulate_parameter_gradients!(::Main.Learnable.Jaynes.Gradients, ::Main.Learnable.Jaynes.BlackBoxCallSite{Main.Learnable.Jaynes.HierarchicalTrace,Tuple{},Float64}, ::Float64, ::Float64) at /home/mccoy/.julia/dev/Jaynes/src/contexts/backpropagate.jl:157
[16] get_parameter_gradients at /home/mccoy/.julia/dev/Jaynes/src/contexts/backpropagate.jl:204 [inlined]
[17] get_parameter_gradients(::Main.Learnable.Jaynes.BlackBoxCallSite{Main.Learnable.Jaynes.HierarchicalTrace,Tuple{},Float64}, ::Float64) at /home/mccoy/.julia/dev/Jaynes/src/contexts/backpropagate.jl:203
```
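The `typed_hvcat` and `hvcat_fill` frames in the trace above are what a typed matrix literal such as `Float64[4.0 4.0; 4.0 4.0]` lowers to, whereas both 1-D literals (`Float64[4.0, 4.0]` and `Any[4.0, 4.0]`) lower to a call of the form `getindex(T, vals...)`. `hvcat_fill` writes the elements into a preallocated matrix with `setindex!`, which appears to be the mutation Zygote is rejecting in frame [2]. A quick Base-only check of the lowering, as a sketch (the exact printed IR differs between Julia versions, and the variable names are just for illustration):

```julia
# Inspect how Julia lowers the three array literals from this report.
# Only Base is used; neither Zygote nor Jaynes is needed.
vec_f64 = Meta.@lower Float64[4.0, 4.0]
vec_any = Meta.@lower Any[4.0, 4.0]
mat_f64 = Meta.@lower Float64[4.0 4.0; 4.0 4.0]

println(vec_f64)  # shows a call to Base.getindex with Float64 as first argument
println(vec_any)  # shows a call to Base.getindex with Any as first argument
println(mat_f64)  # shows a call to Base.typed_hvcat with rows (2, 2)

# Calling the lowered functions directly builds the same arrays as the literals.
@assert getindex(Float64, 4.0, 4.0) == Float64[4.0, 4.0]
@assert getindex(Any, 4.0, 4.0) == Any[4.0, 4.0]
@assert Base.typed_hvcat(Float64, (2, 2), 4.0, 4.0, 4.0, 4.0) == Float64[4.0 4.0; 4.0 4.0]

# Print which getindex method each 1-D literal dispatches to; if they differ,
# an AD system can end up treating the two literals differently.
println(which(getindex, Tuple{Type{Float64}, Float64, Float64}))
println(which(getindex, Tuple{Type{Any}, Float64, Float64}))
```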
Weirdly enough, I run into the same issue with a multi-dimensional array, even when it is annotated as `Float64`:

```julia
function learnable_hypers()
    l = learnable(:l, Float64[4.0 4.0; 4.0 4.0])
    m = learnable(:m, 10.0)
    q = rand(:q, Normal(l[1], 1.0 + exp(m)))
    return q
end
```

So 1-dimensional `Float64` arrays seem to work, but every other case fails. A possible path forward is to analyze what's going on with Cthulhu.
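Before descending with Cthulhu, it might help to check whether Jaynes or IRTools is involved at all by reproducing the three literal forms against Zygote alone. A minimal sketch, assuming only `Zygote.gradient`; the function names are hypothetical, and I haven't checked which cases fail on a given Zygote release, so it prints each outcome instead of asserting it. The last function is one possible mutation-free way to build a 2x2 matrix (an untyped literal, which lowers to `Base.vect`, followed by `reshape`):

```julia
using Zygote

# Same literal shapes as in the report, with the parameter x as the only input.
f_vec_f64(x) = sum(Float64[x, x])         # lowers to getindex(Float64, x, x)
f_vec_any(x) = sum(Any[x, x])             # lowers to getindex(Any, x, x)
f_mat_f64(x) = sum(Float64[x x; x x])     # lowers to Base.typed_hvcat(Float64, (2, 2), x, x, x, x)

# Hypothetical mutation-free alternative: build a Vector with Base.vect,
# then reshape it into a 2x2 matrix instead of filling a Matrix in place.
f_mat_reshape(x) = sum(reshape([x, x, x, x], 2, 2))

for f in (f_vec_f64, f_vec_any, f_mat_f64, f_mat_reshape)
    try
        println(f, " => ", gradient(f, 1.0))
    catch err
        # Report the error message (e.g. a mutation error) without aborting the loop.
        println(f, " => ", sprint(showerror, err))
    end
end
```

If the `typed_hvcat` case fails here too, the problem would seem to sit in how Zygote handles the literal rather than in the Jaynes pullbacks.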