
Fix bug that caused Flux.params(x) call to not be cached (Closes issue #2040) #2048

Closed

Conversation

christiangnrd (Contributor)

If this PR is accepted, it would close issue #2040, in which constant recompilation caused massive slowdowns when using the GPU.

I am pretty sure that this has the exact same behaviour as in v0.13.5. I reverted to the v0.13.4 version of params!() and added a check for isleaf(x) when x is an AbstractArray{<:Number}.

Let me know if any changes need to be made.

@christiangnrd changed the title from "Fix bug that caused Flux.params(x) call to not be cached (Closes issue #2040" to "Fix bug that caused Flux.params(x) call to not be cached (Closes issue #2040)" Aug 23, 2022
src/functor.jl (outdated review thread, resolved)
@christiangnrd (Contributor Author)

I implemented the DenseArray solution, performance is still fixed, and the tests all passed locally.

@christiangnrd (Contributor Author) commented Aug 24, 2022

I added a test that ensures that params() deals with transposed and adjoint arrays properly. I based the behaviour on the current 0.13.5 behaviour. It did not work with AbstractArray, but the tests pass with the DenseArray commit.

@ToucheSir (Member)

I turned on downstream tests for this commit since params is a widely used and public interface. Will merge if those look good!

@@ -36,16 +36,14 @@ Possible values include:
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)

params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)
@mcabbott (Member) Aug 25, 2022

Can a DenseArray{<:Number} ever not be a leaf?

Suggested change
params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)
params!(p::Params, x::DenseArray{<:Number}, seen) = push!(p, x)

@ToucheSir (Member) Aug 25, 2022

Technically yes, Base.Experimental.Const is a pure wrapper type and subtypes DenseArray. I've seen it used in JuliaGPU libraries, but am unsure if those would ever come in contact with params.
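For reference, a minimal sketch of the wrapper under discussion (illustrative only, assuming plain construction via Base.Experimental.Const as used with Base.Experimental.@aliasscope):

A = [1.0, 2.0, 3.0]
c = Base.Experimental.Const(A)    # wraps A without copying, marking it read-only
c isa DenseArray{<:Number}        # expected true: Const <: DenseArray, so it would hit this method
c[1] == A[1]                      # reads pass through to the parent array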

(Member)

That's an interesting type. But I guess it will always be leaf-like; Functors should treat it as it would an SArray, right?

More broadly, if this method has a test for isleaf, then it has to do something with the other branch, and then it becomes the other method. I guess it could assert isleaf just to make sure you get an error if someone does something really weird.

(Member)

I think it has to be a non-leaf for the same reason Transpose does: shared inner arrays.

RE the other branch, I thought the latest change addressed that but it appears I misremembered. Silently dropping an array instead of recursing is definitely not good.

(Member)

I guess. Although transposing one of two shared arrays is common, marking only one of the two as Const seems perverse.

@mcabbott (Member) Aug 25, 2022

Let me update my suggestion. I think this ought to be safe, and will at least throw an error should someone ever @functor Base.Experimental.Const (or its CUDA analogue):

Suggested change
params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)
function params!(p::Params, x::DenseArray{<:Number}, seen = IdSet())
  # Fast path for the most common case, Array & CuArray. Solves issue 2040.
  Functors.isleaf(x) || error("For efficiency, params believes every DenseArray of numbers is leaflike")
  push!(p, x)
end

(Contributor Author)

This code returns size.(Flux.params((x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi)))) == [(1, 2)].

(Member)

This code and what else?

(Contributor Author)

This code with the above suggestion.

@mcabbott (Member) Aug 25, 2022

Ok. On this example, the suggestion changes nothing compared to the PR. It just turns the isleaf test into an error rather than a silent skip.

I think such a fast method should exist alongside the method which was here before the PR, which handles all cases (but has more branches). That should be correct. Whether it still solves 2040 I don't know.


push!(seen, x)
for child in trainable(x)
  params!(p, child, seen)
end
(Member)

What if I have a leaf type which isn't a DenseArray? The current behaviour is:

julia> using NamedDims, StaticArrays

julia> Flux.params((SA[2.2], 3:4.0, NamedDimsArray([5.0], :x)))
Params([[2.2], 3.0:1.0:4.0, NamedDimsArray([5.0], :x)])

What I meant with the DenseArray idea was that this method could be a short-cut for the common case, in addition to the existing method.

Of course the tests as always don't try very hard. But I do think that it ought to keep working with wrappers like NamedDims.

(Member)

Are there any wrappers already on the dependency chain which have this same behaviour outside of NamedDims?

(Member)

Maybe SubArray? ReshapedArray, SymTridiagonal ... for tests I guess you want something unlikely to be @functor-ed in the future.

julia> Flux.params(view([1,2,3]pi, 1:2))
Params([[3.141592653589793, 6.283185307179586]])

julia> ans[1] isa DenseArray
false

(Member)

SubArrays are a bit of a landmine IMO because they don't "cover" the entirety of the wrapped array. ReshapedArray makes sense though. Was it that or PermutedDimsArray that we found couldn't have its transform easily reversed?

(Member)

IIRC ReshapedArray was the tricky one, as its type doesn't have the shape.
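A small illustration of that point (illustrative only, not from the thread): reshaping a range gives a lazy Base.ReshapedArray, and the target shape appears nowhere in its type.

r = reshape(1:6, 2, 3)   # lazy wrapper around the range
typeof(r)                # something like Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}; no (2, 3) in the type
size(r)                  # (2, 3) is only recoverable from the instance, not from the type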

@christiangnrd (Contributor Author)

@ToucheSir @mcabbott This is great! I'm learning a lot reading your discussions. I'm quite new to Flux and its inner workings. I'm more than happy to continue working on this pull request, but I'll need guidance.
What is my next step, and should we be adding tests that will cover all the use-cases being mentioned here?

@mcabbott (Member)

IMO you should probably restore the method to work the way it did, but add to this "slow" case a fast path which will be taken by ordinary CuArrays.

As you can see, Brian and I have been down the rabbit hole of what should be leaflike before... but I think that testing something like Flux.params((x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi))) should be fine.

@christiangnrd (Contributor Author) commented Aug 25, 2022

I'm adding the test that you mentioned. I checked the behaviour with 0.13.4 and 0.13.5, and I'm wondering which one is the intended behaviour.

0.13.4: size.(Flux.params((x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi)))) == [(2,), (2, 1)]

0.13.5: size.(Flux.params((x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi)))) == [(2,), (1, 2)]

Both find all the parameters, but the shapes are different.

@mcabbott (Member)

I think that's expected. With 0.13.5 (Functors 0.3), it recurses inside Transpose (to see W and W' as the same parameter), while before it didn't.
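A hedged illustration of what "seeing W and W' as the same parameter" means in practice (illustrative only, not the PR's test):

using Flux

W = rand(2, 3)
ps = Flux.params((W, transpose(W)))
# On 0.13.5 / Functors 0.3, params recurses into the Transpose wrapper and
# collects its parent, so the shared array should be deduplicated by identity:
length(ps)    # expected 1
ps[1] === W   # expected true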

@christiangnrd (Contributor Author) commented Aug 25, 2022

Okay, I'll add the test to check for the 0.13.5 behaviour.

> IMO you should probably restore the method to work the way it did, but add to this "slow" case a fast path which will be taken by ordinary CuArrays.

This wouldn't fix the issue on CPUs though. The root issue seems to be that the compiler doesn't cache the result of the function call when that call contains if statements.

Now that I understand the cause, I can easily work around it by calling Flux.params(decoder) at the beginning and passing it in as an argument.
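A rough sketch of that workaround (illustrative only; decoder, loss, opt, and data below are placeholders, not code from the issue):

ps = Flux.params(decoder)   # collect the parameters once, outside the training loop

for (x, y) in data
  gs = Flux.gradient(ps) do
    # reuse the cached ps for the penalty instead of calling params() in the loss
    loss(decoder(x), y) + sum(p -> sum(abs2, p), ps)
  end
  Flux.Optimise.update!(opt, ps, gs)
end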

Should I continue working on this, or should I submit a PR with just the new test and close this issue?

@ToucheSir (Member) commented Aug 25, 2022

> This wouldn't fix the issue on CPUs though. The root issue seems to be that the compiler doesn't cache the result of the function call when that call contains if statements.

Have you verified this (for the proposed solution, that is)? It certainly isn't true in general (otherwise most Flux models would have this problem), so tweaking things a bit may be all that's required.

@christiangnrd (Contributor Author) commented Aug 25, 2022

Good point, I'll edit my comment to specify that it's only in situations where Flux.params() is called many times (like in the case of the regularization in my #2040 example).

I tested on both CPU and GPU, with my convolutional variational autoencoder and a variational autoencoder (Dense layers only). Regularization that calls Flux.params() in every loss function call causes every step to spend 60-80% of its time compiling, instead of only the first 1-2 steps as was the case in 0.13.4.

@ToucheSir (Member)

I should clarify: we know that is true for the current implementation on 0.13.5/master, but is it still true if you implement the suggestion here?

@christiangnrd (Contributor Author)

Oh I see, so would that look something like this:

params!(p::Params, x::CuArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)

function params!(p::Params, x, seen = IdSet())
  if x isa AbstractArray{<:Number} && Functors.isleaf(x)
    return push!(p, x)
  elseif x in seen
    nothing
  else
    push!(seen, x)
    for child in trainable(x)
      params!(p, child, seen)
    end
  end
end

I tried it and it's still spending most of each step compiling.
If that's not what you had in mind, let me know and I'll fix it.

@mcabbott (Member)

> work around it by calling Flux.params(decoder) at the beginning and passing it in as an argument.

If I understand right, you have something like loss + sum(norm, params(model)) inside the gradient call. And changing this to loss + sum(norm, ps) where ps = params(model) is outside the gradient works better.

If that's true, then maybe we shouldn't be trying to make the construction of params AD-friendly; we should just hide it completely from AD. Will something like this work?

function params(m...)
  ps = Params()
  ignore_derivatives() do
    params!(ps, m)
  end
  return ps
end

julia> using Flux: params, gradient

julia> model = (x=rand(3), y=rand(3)); tot(m) = sum(m.x + m.y);

julia> g = gradient(params(model)) do
         tot(model)
       end
Grads(...)

julia> g[model.x]
3-element Fill{Float64}, with entries equal to 1.0

julia> g2 = gradient(params(model)) do
         tot(model) + sum(sum, params(model))
       end
Grads(...)

julia> g2[model.x]
3-element Fill{Float64}, with entries equal to 2.0

julia> ps = params(model);  # ps defined outside the gradient, as described above

julia> g3 = gradient(params(model)) do
         tot(model) + sum(sum, ps)
       end
Grads(...)

julia> g3[model.x]
3-element Fill{Float64}, with entries equal to 2.0

julia> @eval Zygote function params(m...)
         ps = Params()
         ignore_derivatives() do
           params!(ps, m)
         end
         return ps
       end
params (generic function with 1 method)

julia> Zygote.refresh()

julia> g4 = gradient(params(model)) do
         tot(model) + sum(sum, params(model))
       end
Grads(...)

julia> g4[model.x]
3-element Fill{Float64}, with entries equal to 2.0

@ToucheSir (Member) commented Aug 27, 2022

I'm all for this if it can be done while keeping existing nested AD code working. Tricky bits include making sure to call accum_param manually on any params collected because ignore_derivatives drops them otherwise and breaks gradient(() -> gradient(...), ps). This would require at least one AD rule to get at the underlying Context.

@christiangnrd (Contributor Author)

@ToucheSir Is there a test for nested gradient calls? If not could you provide a sample use-case for my understanding (and to turn into a test)?

@ToucheSir (Member)

I'm not aware of any, so I tried coming up with one. Funnily enough, all the examples I could think of either didn't work on master or behaved as if params didn't propagate any gradients in the first place! So perhaps this is an even simpler matter of marking params itself as @non_differentiable.

A couple of working examples from my testing, which behave exactly the same if params is marked non-diff:

using Flux, LinearAlgebra

x = ones(1, 1)
d = Dense([2.0;;], [3.0])

gradient(() -> sum(d(x)) + sum(p -> 2norm(p), Flux.params(d)), Flux.params(d)).grads
gradient(() -> sum(d(x)) + sum(gradient(() -> sum(d.weight), Flux.params(d))[d.weight]), Flux.params(d))

@christiangnrd (Contributor Author)

It seems like adding @non_differentiable params(m...) right after defining params fixes the caching issue. I ran the tests locally and they all passed. I'll open a new pull request, since it's one line added to the current master and this PR is getting messy.
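For context, a minimal sketch of what that one-line change might look like next to the existing definition (illustrative only; the actual diff lives in #2054):

using ChainRulesCore: @non_differentiable

# Existing definition, roughly as on master at the time:
function params(m...)
  ps = Params()
  params!(ps, m)
  return ps
end

# The added line: tell AD (Zygote via ChainRules) that collecting parameters has
# no derivative, so the call is never traced inside gradient computations.
@non_differentiable params(m...)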

ToucheSir added a commit that referenced this pull request Aug 30, 2022
Make params non-differentiable (Closes #2040 & #2048)
@christiangnrd (Contributor Author)

Closing this since PR #2054 supersedes it and was merged!

@christiangnrd christiangnrd deleted the 0.13.5_regression_fix branch August 30, 2022 00:31