Skip to content

Commit 814ecdf

Browse files
Merge pull request #160 from MartinuzziFrancesco/fm/simple
refac: simplify integration_fn
2 parents 7ac7dbe + ffc8359 commit 814ecdf

29 files changed

+58
-291
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecurrentLayers"
22
uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c"
33
authors = ["Francesco Martinuzzi"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/base_functions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ end
2323
function add_bias!(weight_inporstate::AbstractMatrix, bias::AbstractVector)
2424
@assert size(weight_inporstate, 1) == length(bias)
2525
@inbounds for jdx in axes(weight_inporstate, 2), idx in axes(weight_inporstate, 1)
26+
2627
weight_inporstate[idx, jdx] += bias[idx]
2728
end
2829
return weight_inporstate

src/cells/antisymmetricrnn_cell.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,7 @@ function AntisymmetricRNNCell(
8383
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8484
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
8585
T = eltype(weight_ih)
86-
if integration_mode == :addition
87-
integration_fn = add_projections
88-
elseif integration_mode == :multiplicative_integration
89-
integration_fn = mul_projections
90-
else
91-
throw(ArgumentError(
92-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
93-
))
94-
end
86+
integration_fn = _integration_fn(integration_mode)
9587
return AntisymmetricRNNCell(activation, weight_ih, weight_hh, bias_ih,
9688
bias_hh, T(epsilon), T(gamma), integration_fn)
9789
end
@@ -280,15 +272,7 @@ function GatedAntisymmetricRNNCell(
280272
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
281273
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
282274
T = eltype(weight_ih)
283-
if integration_mode == :addition
284-
integration_fn = add_projections
285-
elseif integration_mode == :multiplicative_integration
286-
integration_fn = mul_projections
287-
else
288-
throw(ArgumentError(
289-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
290-
))
291-
end
275+
integration_fn = _integration_fn(integration_mode)
292276
return GatedAntisymmetricRNNCell(
293277
weight_ih, weight_hh, bias_ih, bias_hh, T(epsilon), T(gamma), integration_fn)
294278
end

src/cells/br_cell.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,7 @@ function BRCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7979
if !independent_recurrence
8080
@warn "independent_recurrence defaults to true in BRCell"
8181
end
82-
if integration_mode == :addition
83-
integration_fn = add_projections
84-
elseif integration_mode == :multiplicative_integration
85-
integration_fn = mul_projections
86-
else
87-
throw(ArgumentError(
88-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
89-
))
90-
end
82+
integration_fn = _integration_fn(integration_mode)
9183
return BRCell(weight_ih, weight_hh, bias_ih, bias_hh, integration_fn)
9284
end
9385

@@ -278,15 +270,7 @@ function NBRCell((input_size, hidden_size)::Pair{<:Int, <:Int};
278270
if independent_recurrence
279271
@warn "independent_recurrence defaults to false in NBRCell"
280272
end
281-
if integration_mode == :addition
282-
integration_fn = add_projections
283-
elseif integration_mode == :multiplicative_integration
284-
integration_fn = mul_projections
285-
else
286-
throw(ArgumentError(
287-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
288-
))
289-
end
273+
integration_fn = _integration_fn(integration_mode)
290274
return NBRCell(weight_ih, weight_hh, bias_ih, bias_hh, integration_fn)
291275
end
292276

src/cells/cfn_cell.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,7 @@ function CFNCell((input_size, hidden_size)::Pair{<:Int, <:Int};
8282
end
8383
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8484
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
85-
if integration_mode == :addition
86-
integration_fn = add_projections
87-
elseif integration_mode == :multiplicative_integration
88-
integration_fn = mul_projections
89-
else
90-
throw(ArgumentError(
91-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
92-
))
93-
end
85+
integration_fn = _integration_fn(integration_mode)
9486
return CFNCell(weight_ih, weight_hh, bias_ih, bias_hh, integration_fn)
9587
end
9688

src/cells/cornn_cell.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,7 @@ function coRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int},
9797
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
9898
bias_ch = create_bias(weight_ch, cell_bias, size(weight_ch, 1))
9999
T = eltype(weight_ih)
100-
if integration_mode == :addition
101-
integration_fn = add_projections
102-
elseif integration_mode == :multiplicative_integration
103-
integration_fn = mul_projections
104-
else
105-
throw(ArgumentError(
106-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
107-
))
108-
end
100+
integration_fn = _integration_fn(integration_mode)
109101
return coRNNCell(weight_ih, weight_hh, weight_ch, bias_ih, bias_hh, bias_ch,
110102
integration_fn, T(dt), T(gamma), T(epsilon))
111103
end

src/cells/fastrnn_cell.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,7 @@ function FastRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, activation=t
9090
T = eltype(weight_ih)
9191
alpha = T(init_alpha) .* ones(T, 1)
9292
beta = T(init_beta) .* ones(T, 1)
93-
if integration_mode == :addition
94-
integration_fn = add_projections
95-
elseif integration_mode == :multiplicative_integration
96-
integration_fn = mul_projections
97-
else
98-
throw(ArgumentError(
99-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
100-
))
101-
end
93+
integration_fn = _integration_fn(integration_mode)
10294
return FastRNNCell(weight_ih, weight_hh, bias_ih, bias_hh, integration_fn,
10395
alpha, beta, activation)
10496
end
@@ -305,15 +297,7 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast;
305297
T = eltype(weight_ih)
306298
zeta = T(init_zeta) .* ones(T, 1)
307299
nu = T(init_nu) .* ones(T, 1)
308-
if integration_mode == :addition
309-
integration_fn = add_projections
310-
elseif integration_mode == :multiplicative_integration
311-
integration_fn = mul_projections
312-
else
313-
throw(ArgumentError(
314-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
315-
))
316-
end
300+
integration_fn = _integration_fn(integration_mode)
317301
return FastGRNNCell(weight_ih, weight_hh, bias_ih, bias_hh, bias_alt, integration_fn,
318302
zeta, nu, activation)
319303
end

src/cells/indrnn_cell.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,7 @@ function IndRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, activation=re
7979
weight_hh = vec(init_recurrent_kernel(hidden_size))
8080
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8181
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
82-
if integration_mode == :addition
83-
integration_fn = add_projections
84-
elseif integration_mode == :multiplicative_integration
85-
integration_fn = mul_projections
86-
else
87-
throw(ArgumentError(
88-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
89-
))
90-
end
82+
integration_fn = _integration_fn(integration_mode)
9183
return IndRNNCell(activation, weight_ih, weight_hh, bias_ih, bias_hh, integration_fn)
9284
end
9385

src/cells/janet_cell.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,7 @@ function JANETCell((input_size, hidden_size)::Pair{<:Int, <:Int};
8787
beta = fill(eltype(weight_ih)(beta_value), hidden_size)
8888
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8989
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
90-
if integration_mode == :addition
91-
integration_fn = add_projections
92-
elseif integration_mode == :multiplicative_integration
93-
integration_fn = mul_projections
94-
else
95-
throw(ArgumentError(
96-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
97-
))
98-
end
90+
integration_fn = _integration_fn(integration_mode)
9991
return JANETCell(weight_ih, weight_hh, bias_ih, bias_hh, integration_fn, beta)
10092
end
10193

src/cells/lem_cell.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,7 @@ function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0f0
9797
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
9898
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
9999
bias_ch = create_bias(weight_ch, cell_bias, size(weight_ch, 1))
100-
if integration_mode == :addition
101-
integration_fn = add_projections
102-
elseif integration_mode == :multiplicative_integration
103-
integration_fn = mul_projections
104-
else
105-
throw(ArgumentError(
106-
"integration_mode must be :addition or :multiplicative_integration; got $integration_mode"
107-
))
108-
end
100+
integration_fn = _integration_fn(integration_mode)
109101
return LEMCell(weight_ih, weight_hh, weight_ch, bias_ih, bias_hh, bias_ch,
110102
integration_fn, eltype(weight_ih)(dt))
111103
end

0 commit comments

Comments
 (0)