Skip to content

Commit 829416f

Browse files
authored
Merge pull request #55 from cyber-physical-systems-group/modules/groups-in-embedding
Modules/groups in embedding
2 parents d565ff8 + d2d57db commit 829416f

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

pydentification/models/networks/transformer/embedding.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,22 @@ class ConstantLengthEmbedding(nn.Module):
1313

1414
def __init__(
1515
self,
16-
n_time_steps: int,
1716
n_input_state_variables: int,
1817
n_output_state_variables: int,
1918
bias: bool = True,
2019
):
2120
"""
22-
:param n_time_steps: number of time steps in the input signal and embedding
2321
:param n_input_state_variables: number of state input variables
2422
:param n_output_state_variables: number of states to produce
2523
:param bias: if True bias will be used in linear operation
2624
"""
2725
super(ConstantLengthEmbedding, self).__init__()
2826

29-
self.n_input_time_steps = n_time_steps
3027
self.n_input_state_variables = n_input_state_variables
3128
self.n_output_state_variables = n_output_state_variables
3229

3330
self.up_projection = nn.Linear(
34-
in_features=self.n_state_variables, out_features=self.n_output_state_variables, bias=bias
31+
in_features=self.n_input_state_variables, out_features=self.n_output_state_variables, bias=bias
3532
)
3633

3734
def forward(self, inputs: Tensor) -> Tensor:
@@ -43,7 +40,7 @@ class ShorteningCausalEmbedding(nn.Module):
4340
Module converting time series with shape (batch_size, n_input_time_steps, n_input_state_variables) into time series
4441
with shape (batch_size, n_output_time_steps, n_output_state_variables) using two learned linear transformations.
4542
46-
The module can be used as embedding for transformer models, which uprojects the state variables and makes the time
43+
The module can be used as embedding for transformer models, which up-projects the state variables and makes the time
4744
series shorter. The module is always causal.
4845
4946
It works by firstly applying single linear to all time-steps up-projecting the dimensionality. Later convolution
@@ -57,13 +54,15 @@ def __init__(
5754
n_output_time_steps: int,
5855
n_input_state_variables: int,
5956
n_output_state_variables: int,
57+
conv_groups: int = 1,
6058
bias: bool = True,
6159
):
6260
"""
6361
:param n_input_time_steps: number of time steps in the input signal
6462
:param n_output_time_steps: number of time steps to produce after linear operation
6563
:param n_input_state_variables: number of state input variables
6664
:param n_output_state_variables: number of states to produce
65+
:param conv_groups: number of groups in convolutional layer
6766
:param bias: if True bias will be used in linear operation
6867
"""
6968
if n_input_time_steps % n_output_time_steps != 0:
@@ -75,6 +74,7 @@ def __init__(
7574
self.n_output_time_steps = n_output_time_steps
7675
self.n_input_state_variables = n_input_state_variables
7776
self.n_output_state_variables = n_output_state_variables
77+
self.conv_groups = conv_groups
7878

7979
self.up_projection = nn.Linear(
8080
in_features=self.n_input_state_variables, out_features=self.n_output_state_variables, bias=bias
@@ -85,6 +85,7 @@ def __init__(
8585
out_channels=self.n_output_state_variables,
8686
kernel_size=self.n_input_time_steps // self.n_output_time_steps,
8787
stride=self.n_input_time_steps // self.n_output_time_steps,
88+
groups=conv_groups,
8889
bias=bias,
8990
)
9091

0 commit comments

Comments
 (0)