Package Version
Optimisers v0.2.10, ForwardDiff v0.10.33, Flux v0.13.7
Julia Version
1.8
OS / Environment
OS
Describe the bug
The example from the Optimisers.jl documentation that uses ForwardDiff does not give the same result as the equivalent Zygote version: starting from the same seeded model, one Adam step produces a noticeably different loss (see Expected vs. Observed Results below).
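For context, Flux.destructure flattens the model's trainable parameters into a single vector and returns a function that rebuilds the model from such a vector. A minimal illustration of that round trip (my own example, independent of the bug):

using Flux
v, rebuild = Flux.destructure(Dense(2 => 3))  # 2*3 weights + 3 biases
@show length(v);      # 9
model2 = rebuild(v)   # same structure as the original Dense layer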
Steps to Reproduce
using ForwardDiff  # an example of a package which only likes one array
using Flux
using Random
using Optimisers

Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

rule = Optimisers.Adam(0.001f0, (0.9f0, 0.999f0), 1.1920929f-7)
flat, re = Flux.destructure(model)  # flatten trainable parameters into one vector
st = Optimisers.setup(rule, flat)   # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
    loss(re(v), image)  # re(v) rebuilds a new object like model
end

st, flat = Optimisers.update(st, flat, ∇flat)  # one Adam step on the flat vector
@show loss(re(flat), image);
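A sketch that may help narrow this down (my addition, not part of the original steps): comparing the raw ForwardDiff gradient against Zygote's gradient of the same flat vector, before Optimisers.update is called, isolates the AD backends from the Adam step. ∇flat_zygote is a name introduced here for illustration:

# Hedged diagnostic: run this in place of (or before) the update step above.
# Flux.gradient uses Zygote, so this compares the two backends directly.
# Note that BatchNorm mutates its running statistics on each forward pass
# in train mode, so the two evaluations are not at exactly the same state.
∇flat_zygote, = Flux.gradient(v -> loss(re(v), image), flat)
@show maximum(abs, ∇flat .- ∇flat_zygote);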
And the equivalent Zygote version:
using Flux
using Random

Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

opt = Flux.Adam(0.001f0, (0.9f0, 0.999f0), 1.1920929f-7)
θ = Flux.params(model)

grads = Flux.gradient(θ) do  # implicit-parameters style, Zygote under the hood
    loss(model, image)
end

Flux.update!(opt, θ, grads)  # one Adam step, mutating the model in place
@show loss(model, image);
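For a third point of comparison, Optimisers.jl can also be applied to the structured model directly, without destructure. A sketch, assuming the same rule = Optimisers.Adam(...) as in the first snippet; I would expect this to agree with the Flux.update! run above:

st_model = Optimisers.setup(rule, model)        # one Leaf per trainable array
gs = Flux.gradient(m -> loss(m, image), model)  # explicit-parameters style
st_model, model = Optimisers.update(st_model, model, gs[1])
@show loss(model, image);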
Expected Results (Zygote version)
sum(model(image)) = -0.33076355f0
loss(model, image) = -5.064876f0
Observed Results (ForwardDiff version)
sum(model(image)) = -0.33076355f0
loss(re(flat), image) = -7.7023053f0
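One possible confounder worth ruling out (a suggestion on my part, not a conclusion from these numbers): BatchNorm updates its running statistics on every forward pass in train mode, so even the @show sum(model(image)) calls mutate state. Freezing the norm layers before either run would rule that out:

Flux.testmode!(model)  # BatchNorm now uses stored statistics and stops mutating them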
Relevant log output
No response