
ForwardDiff + destructure is different from Zygote, on a model with BatchNorm #2122

Open
lazarusA opened this issue Nov 21, 2022 · 7 comments

@lazarusA

Package Version

Optimisers v0.2.10, ForwardDiff v0.10.33, Flux v0.13.7

Julia Version

1.8

OS / Environment

OS

Describe the bug

The ForwardDiff example from the Optimisers.jl docs unfortunately does not give the same result as the equivalent Zygote version.

Steps to Reproduce

using ForwardDiff  # an example of a package which only likes one array
using Flux
using Random
using Optimisers
Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
          Conv((3, 3), 3 => 5, pad=1, bias=false), 
          BatchNorm(5, relu), 
          Conv((3, 3), 5 => 3, stride=16),
        )
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

rule = Optimisers.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7)

flat, re = Flux.destructure(model)
st = Optimisers.setup(rule, flat)  # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
    loss(re(v), image) # re(v), rebuild a new object like model
end

st, flat = Optimisers.update(st, flat, ∇flat)
@show loss(re(flat),image);

And the Zygote version:

using Flux
using Random
Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
          Conv((3, 3), 3 => 5, pad=1, bias=false), 
          BatchNorm(5, relu), 
          Conv((3, 3), 5 => 3, stride=16),
        )
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

opt = Flux.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7)
θ = Flux.params(model)
grads = Flux.gradient(θ) do 
    loss(model, image)
end

Flux.update!(opt, θ, grads)
@show loss(model, image);

Expected Results

sum(model(image)) = -0.33076355f0
loss(model, image) = -5.064876f0

Observed Results

sum(model(image)) = -0.33076355f0
loss(re(flat), image) = -7.7023053f0

Relevant log output

No response

lazarusA added the bug label on Nov 21, 2022
@mcabbott (Member)

Can reproduce.

I note that commenting out BatchNorm removes the discrepancy.

And that inserting trainmode!(model) produces this error:

ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12})
Stacktrace:
...
 [12] _track_stats!(bn::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, μ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, σ²::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, reduce_dims::Vector{Int64})
    @ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:278
 [13] _norm_layer_forward(l::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}; reduce_dims::Vector{Int64}, affine_shape::NTuple{4, Int64})
    @ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:253
...
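
For reference, a sketch of one way to insert trainmode! into the reproduction script above that hits this error (re-using the names from that script; this is just the shape of the change, not a new minimal example):

trainmode!(model)                       # force the BatchNorm into training mode
flat, re = Flux.destructure(model)      # re(v) keeps the layer's training flag
∇flat = ForwardDiff.gradient(v -> loss(re(v), image), flat)   # hits the MethodError in _track_stats!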

@ToucheSir (Member)

For this last error, we'd need some way to pull the value out of the Duals passed to _track_stats!. Alternatively, is there some way to mark that function as non-differentiable?
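
(For the ChainRules half of that question: a declaration along the lines below would hide such a function from ChainRules-aware ADs like Zygote, but ForwardDiff never consults ChainRules, so it would not fix the Dual MethodError. The function here is a stand-in, not Flux's actual _track_stats!.)

using ChainRulesCore

_track_stats_demo!(bn, x, μ, σ², dims) = nothing   # stand-in with the same arity as the real stats update
ChainRulesCore.@non_differentiable _track_stats_demo!(::Any, ::Any, ::Any, ::Any, ::Any)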

@mcabbott (Member)

I think we can harmlessly just insert value into that broadcast. It's ForwardDiff-specific but maybe worth having?
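
A sketch of that idea, outside of Flux's actual code (the helper name _value is made up):

using ForwardDiff: Dual, value

_value(x::Real) = x          # plain numbers pass through
_value(x::Dual) = value(x)   # drop the derivative part of a Dual

μ    = zeros(Float32, 5)                # running mean stays a plain Float32 vector
μnew = Dual.(rand(Float32, 5), 1f0)     # what the layer sees under ForwardDiff
mtm  = 0.1f0
μ .= (1 - mtm) .* μ .+ mtm .* _value.(μnew)   # "insert value into that broadcast"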

For automatic train-mode, if we do something like FluxML/NNlib.jl#434 then we can have a method for AbstractArray{<:Dual}. But I don't know what package ought to own it.
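
A minimal sketch of what that method could look like (within_ad is a placeholder name, not whatever NNlib.jl#434 settles on):

using ForwardDiff: Dual

within_ad(x) = false                          # default: plain forward pass
within_ad(x::AbstractArray{<:Dual}) = true    # ForwardDiff shows up in the element type

within_ad(rand(Float32, 3))                   # false
within_ad(Dual.(rand(Float32, 3), 1f0))       # true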

@ToucheSir (Member)

So one fly in the ointment is that I was hoping to move track_stats! et al. out to NNlib soonish, and NNlib can't rely on ForwardDiff being loaded. Which is a good segue to:

> For automatic train-mode, if we do something like FluxML/NNlib.jl#434 then we can have a method for AbstractArray{<:Dual}. But I don't know what package ought to own it.

One path would be to use Requires in NNlib to get non-ChainRules ADs to conform to FluxML/NNlib.jl#434. Another would be adding it to AbstractDifferentiation.jl, which already uses Requires for ForwardDiff + ReverseDiff + Tracker. Any other ideas I had (e.g. splitting off Dual numbers from ForwardDiff and having NNlib define methods on them) feel too far off to be feasible.
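
Roughly what the Requires route could look like, assuming an NNlib-side query function like the placeholder above (module and function names are illustrative, not actual NNlib code):

module FakeNNlib

using Requires

within_ad(x) = false   # fallback: assume we are not inside an AD pass

function __init__()
    # only defined once a downstream package loads ForwardDiff
    @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
        within_ad(x::AbstractArray{<:ForwardDiff.Dual}) = true
    end
end

end # module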

@mcabbott (Member)

I had to check, but NNlib is much lighter than ForwardDiff, even if the latter moves over to StaticArraysCore. NNlib does load Requires though, so that might be fine:

julia> @time_imports using NNlib
      0.3 ms  Requires
      0.9 ms  DelimitedFiles
      0.3 ms  Compat
     80.0 ms  ChainRulesCore
      0.3 ms  Adapt
     24.3 ms  NNlib 55.88% compilation time (14% recompilation)

mcabbott changed the title from "Output from ForwardDiff and Optimisers is different from the one given by Zygote" to "ForwardDiff + destructure is different from Zygote, on a model with BatchNorm" on Nov 23, 2022
mcabbott transferred this issue from FluxML/Optimisers.jl on Nov 23, 2022
@ToucheSir (Member) commented Nov 24, 2022

AbstractDiff is pretty similar. Let me file an issue over there and see how it goes. We can always look into the NNlib option in parallel.

@mcabbott (Member)

Besides detecting whether you are within AD (also an issue for dropout), the problem with BatchNorm is that ForwardDiff runs the forward pass several times (chunked mode, for any large array).

I can't think of a good way to detect that. Perhaps we should make it a clear error instead?
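
For what it's worth, the repeated forward passes are easy to see with a call counter (a standalone illustration, nothing Flux-specific):

using ForwardDiff

const NCALLS = Ref(0)
f(v) = (NCALLS[] += 1; sum(abs2, v))

ForwardDiff.gradient(f, rand(100))
NCALLS[]   # > 1: one call per chunk, so any state mutated inside f is updated that many times per gradient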
