Skip to content

Commit cf4b9d3

Browse files
Merge pull request #161 from MartinuzziFrancesco/fm/ir
refac: reduce code reuse for independent recurrence
2 parents 814ecdf + 0338cb2 commit cf4b9d3

24 files changed

+38
-135
lines changed

src/cells/atr_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,7 @@ function ATRCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7373
bias::Bool=true, recurrent_bias::Bool=true,
7474
integration_mode::Symbol=:addition, independent_recurrence::Bool=false)
7575
weight_ih = init_kernel(hidden_size, input_size)
76-
if independent_recurrence
77-
weight_hh = vec(init_recurrent_kernel(hidden_size))
78-
else
79-
weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
80-
end
76+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
8177
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8278
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
8379
if integration_mode == :addition

src/cells/cfn_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,7 @@ function CFNCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7575
bias::Bool=true, recurrent_bias::Bool=true,
7676
integration_mode::Symbol=:addition, independent_recurrence::Bool=false)
7777
weight_ih = init_kernel(hidden_size * 3, input_size)
78-
if independent_recurrence
79-
weight_hh = vec(init_recurrent_kernel(hidden_size * 2))
80-
else
81-
weight_hh = init_recurrent_kernel(hidden_size * 2, hidden_size)
82-
end
78+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size, 2)
8379
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8480
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
8581
integration_fn = _integration_fn(integration_mode)

src/cells/cornn_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,7 @@ function coRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int},
8787
cell_bias::Bool=true, integration_mode::Symbol=:addition,
8888
independent_recurrence::Bool=false)
8989
weight_ih = init_kernel(hidden_size, input_size)
90-
if independent_recurrence
91-
weight_hh = vec(init_recurrent_kernel(hidden_size))
92-
else
93-
weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
94-
end
90+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
9591
weight_ch = init_cell_kernel(hidden_size, hidden_size)
9692
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
9793
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))

src/cells/fastrnn_cell.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,7 @@ function FastRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, activation=t
8080
integration_mode::Symbol=:addition,
8181
independent_recurrence::Bool=false)
8282
weight_ih = init_kernel(hidden_size, input_size)
83-
if independent_recurrence
84-
weight_hh = vec(init_recurrent_kernel(hidden_size))
85-
else
86-
weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
87-
end
83+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
8884
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8985
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
9086
T = eltype(weight_ih)
@@ -286,11 +282,7 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast;
286282
integration_mode::Symbol=:addition,
287283
independent_recurrence::Bool=false)
288284
weight_ih = init_kernel(hidden_size, input_size)
289-
if independent_recurrence
290-
weight_hh = vec(init_recurrent_kernel(hidden_size))
291-
else
292-
weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
293-
end
285+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
294286
bias_alt = create_bias(weight_ih, alt_bias, 2 * size(weight_ih, 1))
295287
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
296288
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))

src/cells/janet_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,7 @@ function JANETCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7979
integration_mode::Symbol=:addition,
8080
independent_recurrence::Bool=false, beta_value::AbstractFloat=1.0f0)
8181
weight_ih = init_kernel(hidden_size * 2, input_size)
82-
if independent_recurrence
83-
weight_hh = vec(init_recurrent_kernel(2 * hidden_size))
84-
else
85-
weight_hh = init_recurrent_kernel(hidden_size * 2, hidden_size)
86-
end
82+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size, 2)
8783
beta = fill(eltype(weight_ih)(beta_value), hidden_size)
8884
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8985
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))

src/cells/lem_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,7 @@ function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0f0
8888
cell_bias::Bool=true, integration_mode::Symbol=:addition,
8989
independent_recurrence::Bool=false)
9090
weight_ih = init_kernel(hidden_size * 4, input_size)
91-
if independent_recurrence
92-
weight_hh = vec(init_recurrent_kernel(hidden_size * 3))
93-
else
94-
weight_hh = init_recurrent_kernel(hidden_size * 3, hidden_size)
95-
end
91+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size, 3)
9692
weight_ch = init_cell_kernel(hidden_size, hidden_size)
9793
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
9894
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))

src/cells/lightru_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@ function LightRUCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7272
integration_mode::Symbol=:addition,
7373
independent_recurrence::Bool=false)
7474
weight_ih = init_kernel(2 * hidden_size, input_size)
75-
if independent_recurrence
76-
weight_hh = vec(init_recurrent_kernel(hidden_size))
77-
else
78-
weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
79-
end
75+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
8076
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8177
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
8278
integration_fn = _integration_fn(integration_mode)

src/cells/ligru_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,7 @@ function LiGRUCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7575
integration_mode::Symbol=:addition,
7676
independent_recurrence::Bool=false)
7777
weight_ih = init_kernel(2 * hidden_size, input_size)
78-
if independent_recurrence
79-
weight_hh = vec(init_recurrent_kernel(2 * hidden_size))
80-
else
81-
weight_hh = init_recurrent_kernel(2 * hidden_size, hidden_size)
82-
end
78+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size, 2)
8379
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8480
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
8581
integration_fn = _integration_fn(integration_mode)

src/cells/mgu_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,7 @@ function MGUCell((input_size, hidden_size)::Pair{<:Int, <:Int};
7373
integration_mode::Symbol=:addition,
7474
independent_recurrence::Bool=false)
7575
weight_ih = init_kernel(2 * hidden_size, input_size)
76-
if independent_recurrence
77-
weight_hh = vec(init_recurrent_kernel(2 * hidden_size))
78-
else
79-
weight_hh = init_recurrent_kernel(2 * hidden_size, hidden_size)
80-
end
76+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size, 2)
8177
bias_ih = create_bias(weight_ih, bias, size(weight_ih, 1))
8278
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))
8379
integration_fn = _integration_fn(integration_mode)

src/cells/minimalrnn_cell.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,7 @@ function MinimalRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int};
8080
integration_mode::Symbol=:addition,
8181
independent_recurrence::Bool=false)
8282
weight_ih = init_encoder_kernel(hidden_size, input_size)
83-
if independent_recurrence
84-
weight_hh = vec(init_recurrent_kernel(hidden_size))
85-
else
86-
weight_hh = init_recurrent_kernel(hidden_size, hidden_size)
87-
end
83+
weight_hh = _indrec_matrix(independent_recurrence, init_recurrent_kernel, hidden_size)
8884
weight_mm = init_memory_kernel(hidden_size, hidden_size)
8985
bias_ih = create_bias(weight_ih, encoder_bias, size(weight_ih, 1))
9086
bias_hh = create_bias(weight_hh, recurrent_bias, size(weight_hh, 1))

0 commit comments

Comments
 (0)