@@ -80,11 +80,7 @@ function FastRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, activation=t
8080 integration_mode:: Symbol = :addition,
8181 independent_recurrence:: Bool = false )
8282 weight_ih = init_kernel(hidden_size, input_size)
83- if independent_recurrence
84- weight_hh = vec(init_recurrent_kernel(hidden_size))
85- else
86- weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
87- end
83+ weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
8884 bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1 ))
8985 bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1 ))
9086 T = eltype(weight_ih)
@@ -286,11 +282,7 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast;
286282 integration_mode:: Symbol = :addition,
287283 independent_recurrence:: Bool = false )
288284 weight_ih = init_kernel(hidden_size, input_size)
289- if independent_recurrence
290- weight_hh = vec(init_recurrent_kernel(hidden_size))
291- else
292- weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
293- end
285+ weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
294286 bias_alt = create_bias(weight_ih, alt_bias, 2 * size(weight_ih, 1 ))
295287 bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1 ))
296288 bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1 ))
0 commit comments