ForwardDiff + destructure is different from Zygote, on a model with BatchNorm #2122

Open
@lazarusA

Package Version

Optimisers v0.2.10, ForwardDiff v0.10.33, Flux v0.13.7

Julia Version

1.8

OS / Environment

OS

Describe the bug

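The ForwardDiff + destructure example from the Optimisers documentation does not give the same result as the equivalent Zygote code. Starting from the same model and input, the loss after one Adam step differs between the two.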

Steps to Reproduce

using ForwardDiff  # an example of a package which only likes one array
using Flux
using Random
using Optimisers
Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

rule = Optimisers.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7)

flat, re = Flux.destructure(model)
st = Optimisers.setup(rule, flat)  # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
    loss(re(v), image) # re(v), rebuild a new object like model
end

st, flat = Optimisers.update(st, flat, ∇flat)
@show loss(re(flat),image);

And the Zygote version:

using Flux
using Random
Random.seed!(123)

model = Chain(  # same small model as in the ForwardDiff version, so the results are comparable
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

opt = Flux.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7)
θ = Flux.params(model)
grads = Flux.gradient(θ) do 
    loss(model, image)
end

Flux.update!(opt, θ, grads)
@show loss(model, image);

Expected Results

sum(model(image)) = -0.33076355f0
loss(model, image) = -5.064876f0

Observed Results

sum(model(image)) = -0.33076355f0
loss(re(flat), image) = -7.7023053f0
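
A possible way to narrow this down is to compare the two gradients directly instead of the post-update losses. The sketch below is not part of the original report. It assumes (unconfirmed here) that the difference is related to whether BatchNorm runs in training or test mode, so it pins the mode with Flux.testmode! before differentiating. The smaller 32×32 input and the names g_fwd and g_rev are illustrative choices only.

```julia
using Flux, ForwardDiff, Random
Random.seed!(123)

model = Chain(
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)
image = rand(Float32, 32, 32, 3, 1)  # smaller input than in the report, to keep ForwardDiff fast
loss(m, x) = sum(m(x))

# Assumption: fix BatchNorm to test mode so both AD backends see the same forward pass.
Flux.testmode!(model)

flat, re = Flux.destructure(model)

# Forward-mode gradient with respect to the flat parameter vector.
g_fwd = ForwardDiff.gradient(v -> loss(re(v), image), flat)

# Reverse-mode (Zygote) gradient with respect to the same flat vector.
g_rev = Flux.gradient(v -> loss(re(v), image), flat)[1]

# Largest elementwise difference between the two gradients.
@show maximum(abs, g_fwd .- g_rev)
```

If the two gradients agree with the mode pinned but disagree without the testmode! call, that would point at the mode handling of BatchNorm rather than at destructure or the optimiser step.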

Relevant log output

No response
