@@ -13,25 +13,22 @@ class ConstantLengthEmbedding(nn.Module):
1313
1414 def __init__ (
1515 self ,
16- n_time_steps : int ,
1716 n_input_state_variables : int ,
1817 n_output_state_variables : int ,
1918 bias : bool = True ,
2019 ):
2120 """
22- :param n_time_steps: number of time steps in the input signal and embedding
2321 :param n_input_state_variables: number of state input variables
2422 :param n_output_state_variables: number of states to produce
2523 :param bias: if True bias will be used in linear operation
2624 """
2725 super (ConstantLengthEmbedding , self ).__init__ ()
2826
29- self .n_input_time_steps = n_time_steps
3027 self .n_input_state_variables = n_input_state_variables
3128 self .n_output_state_variables = n_output_state_variables
3229
3330 self .up_projection = nn .Linear (
34- in_features = self .n_state_variables , out_features = self .n_output_state_variables , bias = bias
31+ in_features = self .n_input_state_variables , out_features = self .n_output_state_variables , bias = bias
3532 )
3633
3734 def forward (self , inputs : Tensor ) -> Tensor :
@@ -43,7 +40,7 @@ class ShorteningCausalEmbedding(nn.Module):
4340 Module converting time series with shape (batch_size, n_input_time_steps, n_input_state_variables) into time series
4441 with shape (batch_size, n_output_time_steps, n_output_state_variables) using two learned linear transformations.
4542
46- The module can be used as embedding for transformer models, which uprojects the state variables and makes the time
43+ The module can be used as embedding for transformer models, which up-projects the state variables and makes the time
4744 series shorter. The module is always causal.
4845
4946 It works by firstly applying single linear to all time-steps up-projecting the dimensionality. Later convolution
@@ -57,13 +54,15 @@ def __init__(
5754 n_output_time_steps : int ,
5855 n_input_state_variables : int ,
5956 n_output_state_variables : int ,
57+ conv_groups : int = 1 ,
6058 bias : bool = True ,
6159 ):
6260 """
6361 :param n_input_time_steps: number of time steps in the input signal
6462 :param n_output_time_steps: number of time steps to produce after linear operation
6563 :param n_input_state_variables: number of state input variables
6664 :param n_output_state_variables: number of states to produce
65+ :param conv_groups: number of groups in convolutional layer
6766 :param bias: if True bias will be used in linear operation
6867 """
6968 if n_input_time_steps % n_output_time_steps != 0 :
@@ -75,6 +74,7 @@ def __init__(
7574 self .n_output_time_steps = n_output_time_steps
7675 self .n_input_state_variables = n_input_state_variables
7776 self .n_output_state_variables = n_output_state_variables
77+ self .conv_groups = conv_groups
7878
7979 self .up_projection = nn .Linear (
8080 in_features = self .n_input_state_variables , out_features = self .n_output_state_variables , bias = bias
@@ -85,6 +85,7 @@ def __init__(
8585 out_channels = self .n_output_state_variables ,
8686 kernel_size = self .n_input_time_steps // self .n_output_time_steps ,
8787 stride = self .n_input_time_steps // self .n_output_time_steps ,
88+ groups = conv_groups ,
8889 bias = bias ,
8990 )
9091
0 commit comments