Recurrent network interface updates/design #1678

Open · 1 of 7 tasks
mkschleg opened this issue Jul 26, 2021 · 25 comments

Comments

@mkschleg
Contributor

mkschleg commented Jul 26, 2021

While we were discussing #1675 and #1671, several improvements/updates to the recurrent network API came up. Instead of taking over #1675, @ToucheSir and I thought it would be best to split the needed improvements out into a separate issue so they can be worked on and discussed here. That way we can finish #1675 and move on to the other changes in lockstep.

Any others I'm missing?

@ToucheSir
Member

I added the ConvLSTM issue in the top post so we can track that here as well.

@mkschleg
Contributor Author

I've started working on the Folded interface. It should actually be pretty easy to add without disrupting the current API, but I haven't thought through how CuDNN fits in yet. I'll make a PR after we get #1675 merged, so we can iterate.

@mkschleg
Contributor Author

mkschleg commented Aug 2, 2021

I separated CuDNN support out from 3D array support. They both influence each other, but I think getting the FoldedRNN API right needs a bit of iteration first.

@mkschleg
Contributor Author

I added Bidirectional RNNs as per the conversation in #1686.

bors bot added a commit that referenced this issue Sep 14, 2021
1686: Adding support for folding RNNs over 3d arrays r=DhairyaLGandhi a=mkschleg

From #1678, adding a Recur-like interface for a folded operation with support for 3-dimensional arrays. This is how many users familiar with PyTorch and TensorFlow expect RNNs to work, and there seems to be some desire for this feature as per the discussion in #1671 and with `@jeremiedb`. This will also make the push towards implementing support for the cuDNN versions of RNNs/GRUs/LSTMs more streamlined, as this is the data layout that API expects.

I did a barebones implementation to add support so we can start iterating on the API.

There are several questions that I have lingering with this interface:
- ~Should we support different modes where we return all or only the last hidden state? Is there a better way to do the concat of the hidden states?~
- What kind of tests should we have? Just follow what we currently do for RNNs/LSTMs/GRUs?
- ~For the CPU version, does it make sense not to specialize on the different rnn types? We might be able to take more advantage of BLAS if we specialized on say `Folded{GRU}`.~
- ~Do we want to force the temporal dimension to be the 2nd?~
- ~Do we want this to be stateful? (i.e. allow the user to change what the starting hidden state is rather than state0).~

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)


Co-authored-by: Matthew Schlegel <[email protected]>
Co-authored-by: Matthew Schlegel <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
@CarloLucibello
Member

Now that #1686 is merged, I think we should disambiguate the interface (e.g. a 3D tensor input could be a single time step of batched greyscale images, [width, height, batch_size], or multiple time steps of batched 1D inputs, [num_features, batch_size, seq_length]).

I think that starting from the next breaking release we should always assume that the input's last dimensions are batch and time, and start introducing a deprecation path.
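For example (shapes here are purely illustrative), nothing in the array itself distinguishes the two readings:

```julia
x = rand(Float32, 28, 28, 16)

# Reading 1: a single time step of 16 batched greyscale images,
#            i.e. [width, height, batch_size] = [28, 28, 16].
# Reading 2: a 16-step sequence of batched 1D inputs,
#            i.e. [num_features, batch_size, seq_length] = [28, 28, 16].
# The layer has to assume a convention; the proposal above is that the
# last dimensions are always batch and time.
```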

@DhairyaLGandhi
Member

It's not something that needs deprecation yet.

@CarloLucibello
Member

If we decide to go with this:

> we should always assume that the input's last dimensions are batch and time, and start introducing a deprecation path.

at the very least we should immediately update the docs.

@ToucheSir
Member

👍 for a docs update. I don't think we need to deprecate anything though, because it's not at all clear that Recur is the right struct to hang all this off of long term. What might be nice is if we took this opportunity to create the proposed experimental package/namespace and started prototyping there. My assumption is that most of the back and forth there will be about naming and API design, so being able to change both at a whim while still having something concrete for users to test would be nice.

@dcecchini

I could not find the reference to Bidirectional on #1686 so I created a wrapper inspired by the way it is applied in Keras. The PR with more details can be found here.

The errors in the checks were due to GitHub being unstable this Sunday (I could not connect to github.com).

@mkschleg
Contributor Author

@ToucheSir @DhairyaLGandhi @CarloLucibello and others. Sorry for the recent static, but this semester was absolutely chaotic/brutal for me.

I've started an RNN repo where we can play with the interface. I think the first thing we need to do is figure out the weirdness in the API introduced by #1686. As discussed, it would not be possible to have convolutional RNNs w/ this interface. I'm thinking the recur struct needs to know some amount of information about its cell type for this to work. It is also possible that we could use traits to accomplish this.

This would also be a good place to test CuDNN paths.

I'm curious to see what people think so we can get the RNN API in a better place. I have some ideas and will hopefully get to work on instantiating them over the next few days.

The current repo is just a copy of the current interface, and the tests are the same (but only for the RNN layers).

@mkschleg
Contributor Author

Been playing around w/ the designs in FluxRNNs.jl. The one I've settled on uses traits for the input types and then EllipsesNotation to handle looping over the final index.

For each type of input, all we need to do is implement `forward!(::InputType, m, x)` and `input_size`, and it should dispatch appropriately. The only thing we need to think about is whether we want failure modes and warnings for certain input types. For example, if we have a ConvRNN that expects a single time step to be a 2D array (for batch = 1) or a 3D array (for larger batches) and it receives a 1D array, a nice warning might be better than the dispatch error.

This design also has some odd hard edges. For example, say you use a 2D array as input and assume the dims are in × timesteps; the current implementation assumes the second dimension is actually batch and does not roll through time. This is solved by documenting that the array should be reshaped to in × 1 × timesteps, which is not unreasonable.
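A minimal sketch of the trait dispatch I mean (the trait names, the `ToyRNN` cell, and the `forward!` methods here are illustrative placeholders, not the actual FluxRNNs.jl code):

```julia
# Traits describing how to interpret the input array (illustrative names).
abstract type InputType end
struct BatchedSamples  <: InputType end   # features × batch
struct BatchedSequence <: InputType end   # features × batch × time

input_type(::AbstractVecOrMat)       = BatchedSamples()
input_type(::AbstractArray{<:Any,3}) = BatchedSequence()

# A toy vanilla RNN cell, only to make the sketch runnable.
mutable struct ToyRNN
    Wi::Matrix{Float32}
    Wh::Matrix{Float32}
    h::Matrix{Float32}   # hidden state; starts as out × 1 and broadcasts over the batch
end
ToyRNN(in, out) = ToyRNN(randn(Float32, out, in), randn(Float32, out, out), zeros(Float32, out, 1))

forward!(m, x) = forward!(input_type(x), m, x)

function forward!(::BatchedSamples, m::ToyRNN, x)
    h = tanh.(m.Wi * x .+ m.Wh * m.h)   # out × batch
    m.h = h
    return h
end

# Fold over the last (time) dimension, stacking the per-step outputs into a 3D array.
function forward!(::BatchedSequence, m::ToyRNN, x)
    ys = [forward!(BatchedSamples(), m, view(x, :, :, t)) for t in axes(x, 3)]
    return cat(ys...; dims = 3)
end

m = ToyRNN(4, 8)
y2d = forward!(m, rand(Float32, 4, 16))       # single step, batch of 16 → 8 × 16
y3d = forward!(m, rand(Float32, 4, 16, 10))   # 10 time steps → 8 × 16 × 10
```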

@mkschleg
Contributor Author

mkschleg commented Dec 26, 2021

My next objective is to actually implement a ConvLSTM to see how it works w/ the interface.

@ToucheSir
Member

One thing I didn't think of for #1686 but could help quite a bit with performance is using eachslice for 3D RNN inputs. That also brings up the question of whether the 3D interface shouldn't be a Vector of arrays one instead, as eachslice/ArraysOfArrays/etc. give all the benefits of contiguous memory while keeping the nested array interface. That could also resolve some of the ambiguity around which axis is which, though it would lock the API into having timesteps last (not the end of the world, but there have been occasional requests for having batch last).
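To make the comparison concrete, a rough sketch of the two formulations (shapes are illustrative, and the equivalence of the two paths is my assumption rather than something tested here):

```julia
using Flux

m = RNN(4, 8)
x = rand(Float32, 4, 16, 10)            # features × batch × time

# 3D path from #1686: the layer folds over the last dimension internally.
y3d = m(x)

# Slice-based formulation: eachslice yields views (no copies), keeping a
# vector-of-matrices interface while the underlying data stays contiguous.
Flux.reset!(m)
xs = collect(eachslice(x; dims = 3))    # Vector of 4×16 views
ys = [m(xi) for xi in xs]               # Vector of 8×16 outputs
```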

@mkschleg
Contributor Author

mkschleg commented Dec 28, 2021

Oh! Right, this was what I was looking for at some point. I'll do some comparisons and see. I think eachslice would be much cleaner and we could probably remove the need for EllipsesNotation. Even though that package is awesome... lol.

I have no qualms with setting the assumption that the last dimension is time, as this makes the most sense to me.

When we consider the types of inputs we may need for a recurrent architecture, each cell will have three options:

  • Just an input (no batching) of the appropriate size: in
  • An input in batches: in x samples
  • An input w/ batch over time: in x samples x timesteps

These three feel appropriate, and I don't think it would be useful to have a utility which can take something that is in x time. All of these naturally place the timesteps as the last dimension. If others want it in another place, they can implement their own recur struct as it still has very few responsibilities and is pretty easy to modify.

The one we might be able to remove is the first, forcing people to shape inputs as in x 1. But this would break a chain on plain vector inputs (say something like `Chain(Dense(...), RNN(...), Dense(...))`), as passing a vector into this model would fail at the RNN. I like having that as an option, but I could be convinced otherwise as well.

I still need to read #1790 to figure out how Bidirectional fits into all this, but from some brief skimming it looks like it can be handled outside of recur, which is likely for the best.

@mkschleg
Contributor Author

mkschleg commented Jan 22, 2022

Ok. From the above, eachslice is slightly better than (but likely about the same as) our view-based approach, so I would say that is a good option, and I've switched to it in the repo I'm working in. In the FluxRNNs repo, I've been working on some tools to measure and compare the performance of RNNs and store/plot the data, so I can make changes without having to go back and manually check for regressions. It is inspired by what Flux does already, but adapted to work w/ RNNs. We can eventually use some of this in Flux.

Instead of going towards ConvLSTMs (I'm still looking for a canonical implementation to base things off of), I started playing w/ cuDNN paths in this branch. There are some challenges in getting the new cuDNN interface from CUDA.jl working w/ our RNNs that I'm still working through.

Right now, cuDNN expects the weights to be in a block matrix of size (out, in+out+bias), which is not what we do. I implemented a CuRNN which is literally just a simple RNN, except we store the weights in a single block and use views to access its parts. I see a few options; opinions on them would be good:

  1. Use a `cat` to combine the weights right before this operation.
  2. Have a different struct that we return when passing through the `gpu` function, which manages this, maybe something like:

```julia
struct CuRNNCell{F,A,V,M,S}
  σ::F
  Wi::A # view into W
  Wh::A # view into W
  b::V  # view into W
  W::M
  state0::S
end
```

  3. Change how RNN weights are stored for all RNNs in Flux.

I was looking to see how this was handled in the past (i.e. v0.10.0), and what was done was to use a function CUDNN.set_weights! which doesn't seem to be supported by CUDA.jl.

I think this is a bit bigger of an issue than the conv issue, and we should try and resolve this relatively soon before making too many changes to RNNs.
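For concreteness, here is a sketch of the view-into-a-block idea from option 2; the `BlockedRNNCell` name, the (out, in + out + 1) column layout, and the constructor are my own illustration rather than the actual cuDNN parameter layout:

```julia
using Flux

# Sketch: one contiguous block W of size (out, in + out + 1); Wi, Wh, b are views into it.
struct BlockedRNNCell{F,A,V,M,S}
    σ::F
    Wi::A        # view of W, columns 1:in
    Wh::A        # view of W, columns in+1:in+out
    b::V         # view of W, the last column
    W::M         # the contiguous block a cuDNN path would consume
    state0::S
end

function BlockedRNNCell(in::Int, out::Int; σ = tanh, init = Flux.glorot_uniform)
    W  = init(out, in + out + 1)
    Wi = view(W, :, 1:in)
    Wh = view(W, :, in+1:in+out)
    b  = view(W, :, in + out + 1)
    return BlockedRNNCell(σ, Wi, Wh, b, W, zeros(Float32, out))
end

# Same cell semantics as a vanilla RNN cell: return (new state, output).
function (m::BlockedRNNCell)(h, x)
    h′ = m.σ.(m.Wi * x .+ m.Wh * h .+ m.b)
    return h′, h′
end
```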

@ToucheSir
Member

Do weights need to be blocked for stacked (multi-layer) cuDNN RNNs as well? This certainly makes the design more challenging, but I like your view idea. Since basically the entire forward pass would require a custom rrule to make use of cuDNN functionality, there's also more room to optimize inside that rrule.

@mkschleg
Contributor Author

mkschleg commented Jan 22, 2022

Yes. It is really handy to take a look at the docs for cudnnRNNForward. I don't think there is a centralized Documenter-style page for the cuDNN docs, but it is fairly easy to read here. The parameter descriptions are the same for both the forward pass and the AD, and there is even a handy structure that can be used in place of the kwargs.

I think the blocking would be interesting. I think we should avoid having separate CPU/GPU structs and stick w/ a shared structure; it seems like a bad idea to change the structure in a hidden way when going through `gpu`. This would also give us a path to create custom CPU paths/kernels for various configurations, if that is something we want to work on eventually.

I'm going to ping CUDA.jl to see if there is something like the old set_weights, and whether that would be a path worth thinking about for this as well.

@mkschleg
Contributor Author

Anyway. No reason to ping them. I was searching for set_weights instead of setweights 🤦‍♂️. This is in CuArrays.jl. Basically, the weights are copied into a blocked CuArray that is sliced into the correct sizes. This may follow from how cuDNN allocates its parameters through the cudnnRNNDataDescriptor_t (see here), which RNNDesc seems to be taking after. This type is quite opaque and would be a C type defined in cuDNN. I'm still learning how best to read the cuDNN docs, so maybe there is somewhere it is well documented.

In any case, spelunking through some of TensorFlow's source, they also use these structures pretty wholesale afaict. I'm still trying to figure out PyTorch's source, but I would be surprised if they didn't also use this.

All in all, I think moving to implementations where we have blocked weights might just be best for the future. This will probably be more flexible in the end for optimizations, and hopefully won't introduce a performance regression.

@ToucheSir
Member

My understanding from reading the main C API docs is that the underlying representation of these structures is purposefully kept opaque. If generating an RNN descriptor isn't too expensive, that could always be done on the fly. Caching is also an option, as is currently done for some other cuDNN descriptor types.

@mkschleg
Contributor Author

That makes sense.

My guess is that generating the RNN descriptor on the fly shouldn't be too bad. My only concern is how to manage the weights (the cudnnRNNDataDescriptor_t seems to allocate/store its own).

I see a few options:

  1. maybe instead we could just use the descriptor as the "W" and then use cudnnGetRNNWeightParams to get views into the weight space.
  2. If we want to manage our own weights, we could use this to also copy these weights into these views and use caching for the descriptor.
  3. We could ignore the data descriptor for cudnn and instead use the CUDA.CUDNN.cudnnRNNForward w/ our own stored weights (through CuArrays).

I think 3 is pretty sensible, as this is how CUDA.CUDNN is set up, and it is what I proposed above. I think the detour into cudnnRNNDataDescriptor_t might have been a bit of a red herring tbh.

Also, sorry for the dump of info/iteration on my end. Interacting w/ cuDNN is really new to me, so it is taking a while to get up and running.

@mkschleg
Contributor Author

mkschleg commented Feb 2, 2022

First steps towards blocked weights in RNNs: #1855. I would really like some thoughts before I move forward on this for the other cells. This implementation works with cuDNN (I've only worked with the forward pass, but I'm assuming it will also work w/ the backward).

The main change is blocking the weights and adding some useful views.

@mkschleg
Contributor Author

Was perusing the new Optimisers.jl stateless approach, and was wondering if we might want something similar for recurrent cells. This would also be similar to how Flax works, I think. I remember a lot of headaches coming from having state embedded in recur, and I actually have a whole stack for inspecting and resetting state to a value other than s0 (for RL things) in arbitrary chains/models.

Lux.jl already does this, so we might be able to take advantage of their implementation of "Recurrence". Although, I'm not sure where the boundaries between Lux and Flux are being drawn. This would only be concerned with the state of the cell, not the parameterization. I'm also not sure how well this would interact with the chain interface; it could even be impossible.

@ToucheSir
Member

I was thinking much the same thing! Never got around to it though, but now that we have https://github.com/FluxML/Fluxperimental.jl/ this should be easier to prototype.

> I'm also not sure how well this would interact with the chain interface

My rough idea would be:

```julia
function Flux.apply(c::Chain, x)
  l1, y1 = apply(c[1], x)
  l2, y2 = apply(c[2], y1)
  ...
  return Chain(l1, l2, ...), yN
end
```

Instead of updating an externally passed-in piece of state, we return an updated layer object. For layers without any auxiliary, non-parameter state, a fallback `apply(m, x) = m, m(x)` should be sufficient. This is similar to how Accessors.jl works, and indeed we might want to integrate this interface with that library once we've fleshed out a design.
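A minimal sketch of how a recurrent layer might slot into this (the `StatelessRecur` name and the assumption that cells return `(new state, output)` are mine):

```julia
# Sketch: a non-mutating recurrent wrapper. apply returns a new wrapper carrying
# the updated state instead of mutating it in place.
struct StatelessRecur{C,S}
    cell::C
    state::S
end

function apply(r::StatelessRecur, x)
    state, y = r.cell(r.state, x)      # Flux-style cells return (new state, output)
    return StatelessRecur(r.cell, state), y
end

# Fallback for layers without auxiliary, non-parameter state, as above.
apply(m, x) = m, m(x)
```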

@darsnack
Member

And note that an apply interface such as this (or Optimisers.jl) is about explicit (non-)mutation, which is not the same as separating parameters/state from the model (Lux does both).

For examples of how these are different in Jax, you can compare Flax vs Equinox. The latter is closer to Flux's direction.

@mkschleg
Contributor Author

It was a pretty easy interface to implement in Fluxperimental. Pull request here: FluxML/Fluxperimental.jl#5. Once we are happy with the basic interface, we can start fleshing out the functionality for stateful structures.
