
better support for RNN layers required #2170

Open
@sourabh2k15

Description


We have a use case where we're implementing the DeepSpeech2 model in Flax. DeepSpeech2 is an older speech-recognition model based on RNN-style layers (commonly Bi-LSTMs).

Flax doesn't have Bi-LSTMs, so we hacked together a version of our own based on the existing RNNCell, but I messed up the handling of paddings, which caused a long debugging loop.

Eventually we found a Flax BiLSTM layer folks had implemented that flips the sequences before running the LSTM in the reverse direction and then flips the output back; this worked for our use case involving padded inputs.
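For padded batches, a plain `jnp.flip` would move the padding to the front of each sequence, so the flip has to be length-aware. A minimal numpy sketch of such a helper (the name `flip_sequences` and its signature are assumptions here, not a Flax API):

```python
import numpy as np

def flip_sequences(inputs, lengths):
    """Reverse only the valid prefix of each padded sequence.

    inputs:  [batch, time, features] array, padded at the end.
    lengths: [batch] array of valid sequence lengths.
    Padding positions stay where they are, so a reverse-direction RNN
    sees the real tokens first and the padding last.
    """
    max_len = inputs.shape[1]
    t = np.arange(max_len)[None, :]                      # [1, time]
    # Within the valid prefix, map position t -> length - 1 - t;
    # past the end, keep t so padding is untouched.
    idx = np.where(t < lengths[:, None], lengths[:, None] - 1 - t, t)
    return np.take_along_axis(inputs, idx[..., None], axis=1)
```

Running the reverse LSTM on `flip_sequences(x, lengths)` and applying the same helper to its outputs realigns the backward features with the forward ones per time step.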

Overall, the current RNN layers in Flax feel very bare-bones compared to PyTorch, which does RNNs really well. It would be amazing to have full-fledged Bi-LSTM, GRU, and RNN cells ready to go; currently, folks even have to write their own wrapper that uses nn.scan around the default Flax cell primitives.
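The wrapper pattern described above can be sketched in plain numpy: a manual scan over the time axis (the role `nn.scan` plays in Flax), plus a bidirectional combinator that runs one cell forward and another on length-aware flipped inputs. All names here (`rnn_scan`, `bidirectional`, `flip_padded`) are illustrative, not Flax APIs:

```python
import numpy as np

def flip_padded(x, lengths):
    # Reverse only the valid prefix of each padded sequence.
    t = np.arange(x.shape[1])[None, :]
    idx = np.where(t < lengths[:, None], lengths[:, None] - 1 - t, t)
    return np.take_along_axis(x, idx[..., None], axis=1)

def rnn_scan(cell, carry, inputs):
    # Unroll a cell over the time axis; cell(carry, x_t) -> (carry, y_t).
    # This is what nn.scan automates over a Flax RNN cell.
    outputs = []
    for t in range(inputs.shape[1]):
        carry, y = cell(carry, inputs[:, t])
        outputs.append(y)
    return carry, np.stack(outputs, axis=1)

def bidirectional(cell_fwd, cell_bwd, init_carry, inputs, lengths):
    # Forward pass on the inputs as-is; backward pass on flipped inputs,
    # then flip the backward outputs so both align per time step.
    _, fwd = rnn_scan(cell_fwd, init_carry, inputs)
    _, bwd = rnn_scan(cell_bwd, init_carry, flip_padded(inputs, lengths))
    bwd = flip_padded(bwd, lengths)
    return np.concatenate([fwd, bwd], axis=-1)
```

With a toy cumulative-sum cell, `cell = lambda c, x: (c + x, c + x)`, the forward half of the output accumulates left-to-right and the backward half right-to-left, which is the behavior a built-in Bi-LSTM layer would provide out of the box.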

Labels

Priority: P1 - soon (response within 5 business days; resolution within 30 days; assignee required)
