@@ -83,15 +83,7 @@ function AntisymmetricRNNCell(
8383 bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1 ))
8484 bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1 ))
8585 T = eltype(weight_ih)
86- if integration_mode == :addition
87- integration_fn = add_projections
88- elseif integration_mode == :multiplicative_integration
89- integration_fn = mul_projections
90- else
91- throw(ArgumentError(
92- " integration_mode must be :addition or :multiplicative_integration; got $integration_mode "
93- ))
94- end
86+ integration_fn = _integration_fn(integration_mode)
9587 return AntisymmetricRNNCell(activation, weight_ih, weight_hh, bias_ih,
9688 bias_hh, T(epsilon), T(gamma), integration_fn)
9789end
@@ -280,15 +272,7 @@ function GatedAntisymmetricRNNCell(
280272 bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1 ))
281273 bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1 ))
282274 T = eltype(weight_ih)
283- if integration_mode == :addition
284- integration_fn = add_projections
285- elseif integration_mode == :multiplicative_integration
286- integration_fn = mul_projections
287- else
288- throw(ArgumentError(
289- " integration_mode must be :addition or :multiplicative_integration; got $integration_mode "
290- ))
291- end
275+ integration_fn = _integration_fn(integration_mode)
292276 return GatedAntisymmetricRNNCell(
293277 weight_ih, weight_hh, bias_ih, bias_hh, T(epsilon), T(gamma), integration_fn)
294278end
0 commit comments