Description
We have a use case where we're implementing the DeepSpeech2 model in Flax. DeepSpeech2 is an older speech recognition model built on RNN-style layers (commonly Bi-LSTMs).
Flax doesn't have Bi-LSTMs, so we hacked together a version of our own based on the existing RNNCell, but I messed up the handling of padding, and this caused a long debugging loop.
Eventually we found a Flax BiLSTM layer folks had implemented that flips the sequences before running the reverse-direction LSTM and then flips the outputs back. It handles padding correctly, which worked for our use case involving padded inputs (the core trick is sketched below).
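
For reference, the key piece of that approach is a padding-aware sequence flip: roll each sequence so the padding wraps around to the front, then flip the whole time axis, which reverses the valid tokens while leaving padding at the end. A minimal sketch (the function name and exact signature are mine, not necessarily those of the layer we found):

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.vmap, in_axes=(0, 0))
def flip_sequences(inputs, lengths):
  """Reverses the valid tokens of a padded sequence, keeping padding at the end.

  inputs: [time, features] for one example (vmapped over the batch).
  lengths: scalar count of valid (non-padded) time steps for that example.
  """
  max_length = inputs.shape[0]
  # Roll the padding to the front, then flip the whole time axis:
  # [a, b, c, PAD, PAD] -> [PAD, PAD, a, b, c] -> [c, b, a, PAD, PAD]
  return jnp.flip(jnp.roll(inputs, max_length - lengths, axis=0), axis=0)
```

The reverse-direction LSTM then runs on `flip_sequences(x, lengths)`, its outputs go through `flip_sequences` again, and the result is concatenated with the forward-direction outputs.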
Overall, the current RNN layers in Flax feel very bare-bones compared to PyTorch, which supports RNNs really well. It'd be amazing to have full-fledged Bi-LSTM, GRU, and vanilla RNN layers ready to go; currently folks even have to write their own wrapper that uses nn.scan around the default Flax cell primitives, as in the sketch below.
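
To illustrate the boilerplate involved, here is roughly the wrapper one has to write today just to unroll a single-direction LSTM over time. This is a sketch against the current flax.linen API (where nn.LSTMCell takes a features argument; the signature has changed across Flax versions), and the module name is mine:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SimpleLSTM(nn.Module):
  """Unrolls an LSTMCell over the time axis with nn.scan."""
  features: int

  @nn.compact
  def __call__(self, inputs):
    # inputs: [batch, time, channels]
    ScanLSTM = nn.scan(
        nn.LSTMCell,
        variable_broadcast='params',   # share weights across time steps
        split_rngs={'params': False},  # one init RNG for all steps
        in_axes=1, out_axes=1)         # scan over the time axis
    lstm = ScanLSTM(self.features)
    # The default carry init is zeros, so a fixed RNG is fine here.
    carry = lstm.initialize_carry(jax.random.PRNGKey(0), inputs[:, 0].shape)
    _, outputs = lstm(carry, inputs)
    return outputs


# Usage:
model = SimpleLSTM(features=64)
x = jnp.ones((4, 10, 32))  # [batch, time, channels]
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # [4, 10, 64]
```

And that still handles neither the reverse direction nor padding, both of which PyTorch's nn.LSTM covers out of the box.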