This should be 1 - f, according to the paper. The confusion comes from the direction of the "forget" gate: in the LSTM and GRU papers, information is passed through when f is high, but in the MGU paper it is the opposite. The variable f from the MGU paper is effectively 1 - f in Flax (it is the portion that contributes to the short-term response, or n in Flax-speak). From the paper:
In MGU, the forget gate f_t is first generated, and the element-wise product between 1 - f_t and h_{t−1} becomes part of the new hidden state h_t. The portion of h_{t-1} that is "forgotten" (f_t h_{t−1}) is combined with x_t to produce h_bar_t, the short-term response. A portion of h_bar_t (determined again by f_t) form the second part of h_t.
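To make the two gate conventions concrete, here is a minimal scalar sketch of the update described in the quoted passage. It is not the Flax implementation. The weight names (`w_f`, `u_f`, `b_f`, `w_h`, `u_h`, `b_h`) and the helper names are hypothetical, and `mix_gru_style` is only an assumption about what a "high gate keeps the old state" mix looks like.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def mix_paper(f, h_prev, h_bar):
    # MGU-paper convention: (1 - f) keeps the old state, f admits the new response.
    return (1.0 - f) * h_prev + f * h_bar


def mix_gru_style(g, h_prev, n):
    # Assumed "high gate passes the old state through" convention.
    return (1.0 - g) * n + g * h_prev


def mgu_step_paper(x, h_prev, w_f, u_f, b_f, w_h, u_h, b_h):
    """One scalar MGU step following the quoted description (hypothetical weights)."""
    f = sigmoid(w_f * x + u_f * h_prev + b_f)               # forget gate f_t
    h_bar = math.tanh(w_h * x + u_h * (f * h_prev) + b_h)   # short-term response
    return f, h_bar, mix_paper(f, h_prev, h_bar)            # new hidden state h_t


# The two conventions agree only when the gate is flipped: f (paper) == 1 - g.
f, h_prev, h_bar = 0.2, 1.0, -0.5
assert abs(mix_paper(f, h_prev, h_bar) - mix_gru_style(1.0 - f, h_prev, h_bar)) < 1e-12
```

With the paper's convention, a gate near 0 leaves `h_prev` almost unchanged, and a gate near 1 replaces it with the short-term response. Using `f` where `1 - f` is expected swaps those two behaviours.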
flax/flax/linen/recurrent.py
Line 725 in d59132d