On some recent Flux experiences #2171

Open
mcabbott opened this issue Feb 1, 2023 · 4 comments

@mcabbott
Member

mcabbott commented Feb 1, 2023

Apparently @owainkenwayucl was trying out Flux, and @giordano was helping him out.

https://twitter.com/owainkenway/status/1620771011863121921

https://github.com/owainkenwayucl/JuliaML/blob/main/Fashion/simple.jl

Edit: now at https://www.youtube.com/watch?v=Yd1JkPljpbY

It's useful to see what problems newcomers run into, especially people new to Julia.

  • Scope issues, like having to write global a_sum_ = a_sum_ + 1 inside a top-level loop, are confusing. Flux's tutorials tend to define many functions to keep everything in local scope (maybe too many), but for this common case, perhaps Flux ought to provide an accuracy function instead of every tutorial rolling its own?

  • Wrong array dimensions give quite confusing errors. Perhaps Flux layers should catch more of them themselves, instead of leaving it to * and other low-level functions. Some made-up examples (examples from the wild may differ):

julia> Conv((3,3),3=>4)(rand(10,10,3))  # reasonable to try, maybe it should just work
ERROR: DimensionMismatch: Rank of x and w must match! (3 vs. 4)

julia> Conv((3,3),3=>4)(rand(10,10,2,1))  # error could print out which Conv layer, and size of input, etc.
ERROR: DimensionMismatch: Input channels must match! (2 vs. 3)

julia> Dense(2,3)(rand(4))  # error could be from Dense, there is no matrix called A in user code
ERROR: DimensionMismatch: second dimension of A, 2, does not match length of x, 4
  • Perhaps we should encourage more use of outputsize for understanding array sizes. There could also be some way to get an overview of the whole model, like Flax's layer table: https://flax.readthedocs.io/en/latest/getting_started.html#view-model-layers . At present the default show doesn't know the input size, so it can't display all of this. One idea would be to give Chain a mutable field in which to store the most recent input size.
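For the dimension-error point, a layer can check the input size itself before calling `*`, so the message names the layer and the user's input. A minimal sketch, assuming a hypothetical `CheckedDense` layer written in plain Julia (this is not Flux's `Dense`, and the message wording is made up):

```julia
# Hypothetical layer for illustration only; this is not Flux's Dense.
struct CheckedDense{M<:AbstractMatrix,V<:AbstractVector}
    weight::M
    bias::V
end

# Random weights of size nout × nin, zero bias.
CheckedDense(nin::Integer, nout::Integer) = CheckedDense(randn(nout, nin), zeros(nout))

function (l::CheckedDense)(x::AbstractVecOrMat)
    nout, nin = size(l.weight)
    # Check before calling `*`, so the error names this layer and the user's
    # input size, rather than an internal matrix called A.
    if size(x, 1) != nin
        throw(DimensionMismatch(
            "CheckedDense($nin => $nout) expects size(x, 1) == $nin, " *
            "got input of size $(size(x))"))
    end
    return l.weight * x .+ l.bias
end
```

With this check, CheckedDense(2, 3)(rand(4)) fails with a message naming the layer, CheckedDense(2 => 3), and the input size (4,), instead of a message about a matrix A that does not appear in user code.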
@skyleaworlder
Contributor

skyleaworlder commented Feb 4, 2023

Storing the input size seems necessary. Chain itself is a good place for it, on both counts: clear responsibilities and an uncomplicated design.
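The "mutable field in Chain" idea can be sketched in a few lines. This assumes a hypothetical `SizedChain` type and `layer_sizes` helper (neither is Flux API); to recover intermediate sizes it simply reruns the layers on a zero array, whereas Flux's outputsize avoids that real computation:

```julia
# Hypothetical sketch, not Flux's Chain: a chain that remembers the size of
# its most recent input, so that `show` can report per-layer output sizes.
mutable struct SizedChain{T<:Tuple}
    layers::T
    last_input_size::Union{Nothing,Tuple}
end

SizedChain(layers...) = SizedChain{typeof(layers)}(layers, nothing)

function (c::SizedChain)(x::AbstractArray)
    c.last_input_size = size(x)  # side effect: record the most recent input size
    return foldl((y, f) -> f(y), c.layers; init = x)
end

# Size after each layer, recomputed from the stored input size by running the
# layers on a zero array (real compute, unlike Flux's outputsize).
function layer_sizes(c::SizedChain)
    c.last_input_size === nothing && return nothing
    y = zeros(Float32, c.last_input_size)
    sizes = Tuple[]
    for f in c.layers
        y = f(y)
        push!(sizes, size(y))
    end
    return sizes
end

function Base.show(io::IO, c::SizedChain)
    sizes = layer_sizes(c)
    println(io, "SizedChain with ", length(c.layers), " layers:")
    for (i, f) in enumerate(c.layers)
        out = sizes === nothing ? "?" : string(sizes[i])
        println(io, "  ", i, ": ", f, "  => output size ", out)
    end
end
```

The trade-off is visible here: what show prints depends on whichever input the chain happened to see last.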

In Flax, all operators push their input and output sizes onto a list (call_info_stack) held in a global thread-local object (_DynamicContext). When tabulate is called, Flax fetches the latest record from call_info_stack and formats its properties. This is another option.

Flax's approach adds the complexity of thread management, but it is more convenient for sampling or profiling.
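For comparison, the Flax-style design translates to Julia roughly as below. A minimal sketch, assuming hypothetical helpers `traced`, `call_info_stack` and `tabulate_calls` (none exist in Flux); Julia's per-task `task_local_storage()` stands in for Flax's thread-local `_DynamicContext`:

```julia
# Hypothetical sketch of the Flax-style design in Julia; none of these names
# exist in Flux. Each Task's storage dictionary stands in for _DynamicContext.
const CALL_INFO_KEY = :call_info_stack

# Fetch (or lazily create) this task's stack of (name, input size, output size).
call_info_stack() =
    get!(() -> Tuple{String,Tuple,Tuple}[], task_local_storage(), CALL_INFO_KEY)

# Wrap a call so that it records its sizes as a side effect.
function traced(name::AbstractString, f, x)
    y = f(x)
    push!(call_info_stack(), (String(name), size(x), size(y)))
    return y
end

# Print the recorded calls in execution order, then clear the stack.
function tabulate_calls(io::IO = stdout)
    stack = call_info_stack()
    for (name, insize, outsize) in stack
        println(io, rpad(name, 10), insize, " => ", outsize)
    end
    empty!(stack)
    return nothing
end
```

Note that the records live outside the model, so what tabulate_calls prints depends on which calls happened earlier in the same task.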

@ToucheSir
Member

I would like to avoid all the implicit state Flax uses, if possible. It makes errors harder to reason about and introduces a temporal dimension to input-size handling that I'm not sure we want to deal with. If you have layers before and after shape inference, what happens when you pass a new input shape at each stage? How do you know which stage a layer, or a composite of layers, is in? What happens when you compose layers from before and after shape inference? Etc.

@mcabbott
Member Author

mcabbott commented Feb 4, 2023

What both show and outputsize currently avoid is any need to understand how Chain, Parallel, PairwiseFusion, etc. work; there is no layer API to support that. The ways I have thought of writing a tabulate-like function all need something like a re-implementation of such layers, in order to simultaneously print their contents and trace sizes through their constituents. Could that be avoided?

One quick idea would be to duplicate Nil (the dummy number type that outputsize uses) as NilPrint, and then have @layer always define something like (l::Dense)(x::AbstractArray{NilPrint}) = (println(l, " ... ", size(x)); @invoke l(x::AbstractArray)). Then some trace(model, size) function would print the sizes in execution order, without needing to know about Functors or children at all.
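The core of this idea, that only leaf layers need a tracing method while containers just pass their input along, can be sketched without Flux. This is a swapped-in variant, not the NilPrint design itself: instead of a Nil-like element type (which needs arithmetic rules), the input is marked by wrapping it in a hypothetical `Probe` type. `Probe`, `AbstractLayer`, `Affine`, `Relu`, `Seq` and `trace` are all illustrative names:

```julia
# Hypothetical sketch; none of these names are Flux API. The input is marked by
# wrapping it in Probe (carrying an IO to print to) instead of a Nil-like eltype.
struct Probe{A<:AbstractArray,I<:IO}
    data::A
    io::I
end

abstract type AbstractLayer end

# One generic method for every leaf layer (the role `@layer` would play):
# run the layer on the wrapped array, print the sizes, re-wrap the result.
function (l::AbstractLayer)(p::Probe)
    y = l(p.data)
    println(p.io, nameof(typeof(l)), ": ", size(p.data), " => ", size(y))
    return Probe(y, p.io)
end

struct Affine{M<:AbstractMatrix,V<:AbstractVector} <: AbstractLayer
    W::M
    b::V
end
Affine(nin::Integer, nout::Integer) = Affine(randn(nout, nin), zeros(nout))
(l::Affine)(x::AbstractArray) = l.W * x .+ l.b

struct Relu <: AbstractLayer end
(::Relu)(x::AbstractArray) = max.(x, 0)

# A container with no tracing code at all: it just calls its children in
# order, so a Probe flows through it (and through nested Seqs) unchanged.
struct Seq{T<:Tuple}
    layers::T
end
Seq(layers...) = Seq{typeof(layers)}(layers)
(s::Seq)(x) = foldl((y, f) -> f(y), s.layers; init = x)

# Print per-layer sizes in execution order for an input of size `dims`.
trace(model, dims::Tuple; io::IO = stdout) = model(Probe(zeros(dims), io)).data
```

For example, trace(Seq(Affine(4, 3), Relu(), Seq(Affine(3, 2))), (4, 5)) prints one line per leaf layer, including the one inside the nested Seq. Unlike Nil, this runs the real computation on zeros, and plain functions placed in a Seq would need their own Probe method.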

@mcabbott
Member Author

Not the same experience, but see this question for how strange the model zoo's loss and accuracy accumulation functions look. It would be nice to fix that somehow.

https://stackoverflow.com/questions/75921783/accuracy-and-gradient-update-not-within-the-same-training-loop
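Both this and the global-counter issue from the first post go away if the accumulation happens inside a function. A minimal sketch, assuming batches are (x, y) pairs with one-hot labels stored as columns, and a loss that averages over the batch; `accuracy` and `evaluate` are hypothetical helpers, not Flux functions (with Flux loaded, Flux.onecold performs the same argmax-to-label step):

```julia
using Statistics: mean

# Hypothetical helper: fraction of columns whose largest score is in the same
# row as the 1 in the one-hot label column.
accuracy(ŷ::AbstractMatrix, y::AbstractMatrix) =
    mean(argmax(ŷ; dims = 1) .== argmax(y; dims = 1))

# Hypothetical helper: accumulates loss and accuracy over all batches in local
# variables, so no `global` counters are needed. `model` and `loss` are any
# callables; `loss` is assumed to return a per-batch mean.
function evaluate(model, loss, batches)
    total_loss, correct, nseen = 0.0, 0, 0
    for (x, y) in batches
        ŷ = model(x)
        n = size(x, ndims(x))          # batch size is the last dimension
        total_loss += loss(ŷ, y) * n   # undo the per-batch mean
        correct += sum(argmax(ŷ; dims = 1) .== argmax(y; dims = 1))
        nseen += n
    end
    return (loss = total_loss / nseen, accuracy = correct / nseen)
end
```

Weighting each batch's loss by its size keeps the result correct when the last batch is smaller than the others.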
