@@ -100,7 +100,6 @@ def seasonality_model(
100100 "bp,pt->bt" , theta [:, config_per_harmonic : 2 * config_per_harmonic ], forecast_sin_template
101101 )
102102 forecast = forecast_harmonics_sin + forecast_harmonics_cos
103-
104103 return backcast , forecast
105104
106105
@@ -123,9 +122,14 @@ class GenericBlock(tf.keras.layers.Layer):
123122 """
124123
125124 def __init__ (
126- self , train_sequence_length : int , predict_sequence_length : int , hidden_size : int , n_block_layers : int = 4
125+ self ,
126+ train_sequence_length : int ,
127+ predict_sequence_length : int ,
128+ hidden_size : int ,
129+ n_block_layers : int = 4 ,
130+ ** kwargs
127131 ):
128- super (GenericBlock , self ).__init__ ()
132+ super (GenericBlock , self ).__init__ (** kwargs )
129133 self .train_sequence_length = train_sequence_length
130134 self .predict_sequence_length = predict_sequence_length
131135 self .hidden_size = hidden_size
@@ -139,9 +143,9 @@ def build(self, input_shape: Tuple[Optional[int], ...]):
139143 input_shape : Tuple[Optional[int], ...]
140144 Shape of the input tensor
141145 """
146+ super (GenericBlock , self ).build (input_shape )
142147 self .layers = [Dense (self .hidden_size , activation = "relu" ) for _ in range (self .n_block_layers )]
143148 self .theta = Dense (self .train_sequence_length + self .predict_sequence_length , use_bias = False , activation = None )
144- super (GenericBlock , self ).build (input_shape )
145149
146150 def call (self , inputs : tf .Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
147151 """Compute the output of the Generic Block.
@@ -164,6 +168,24 @@ def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
164168 x = self .theta (x )
165169 return generic_model (x , tf .range (self .train_sequence_length ), tf .range (self .predict_sequence_length ))
166170
171+ def compute_output_shape (self , input_shape ):
172+ batch_size = input_shape [0 ]
173+ backcast_shape = (batch_size , self .train_sequence_length )
174+ forecast_shape = (batch_size , self .predict_sequence_length )
175+ return (backcast_shape , forecast_shape )
176+
177+ def get_config (self ):
178+ config = super ().get_config ()
179+ config .update (
180+ {
181+ "train_sequence_length" : self .train_sequence_length ,
182+ "predict_sequence_length" : self .predict_sequence_length ,
183+ "hidden_size" : self .hidden_size ,
184+ "n_block_layers" : self .n_block_layers ,
185+ }
186+ )
187+ return config
188+
167189
168190class TrendBlock (tf .keras .layers .Layer ):
169191 """Trend block that learns trend patterns using polynomial basis functions.
@@ -192,8 +214,9 @@ def __init__(
192214 hidden_size : int ,
193215 n_block_layers : int = 4 ,
194216 polynomial_term : int = 2 ,
217+ ** kwargs
195218 ):
196- super ().__init__ ()
219+ super ().__init__ (** kwargs )
197220
198221 self .train_sequence_length = train_sequence_length
199222 self .predict_sequence_length = predict_sequence_length
@@ -226,12 +249,10 @@ def build(self, input_shape: Tuple[Optional[int], ...]):
226249 input_shape : Tuple[Optional[int], ...]
227250 Shape of the input tensor
228251 """
229-
252+ super (). build ( input_shape )
230253 self .layers = [Dense (self .hidden_size , activation = "relu" ) for _ in range (self .n_block_layers )]
231254 self .theta = Dense (2 * self .polynomial_size , use_bias = False , activation = None )
232255
233- super ().build (input_shape )
234-
235256 def call (self , inputs : tf .Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
236257 """Compute the output of the Trend Block.
237258
@@ -254,14 +275,16 @@ def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
254275 return trend_model (x , self .backcast_time , self .forecast_time , self .polynomial_size )
255276
256277 def compute_output_shape (self , input_shape ):
257- return [( input_shape [0 ], self .train_sequence_length ), (input_shape [0 ], self .predict_sequence_length )]
278+ return (( input_shape [0 ], self .train_sequence_length ), (input_shape [0 ], self .predict_sequence_length ))
258279
259280
260281class SeasonalityBlock (tf .keras .layers .Layer ):
261282 """Seasonality block"""
262283
263- def __init__ (self , train_sequence_length , predict_sequence_length , hidden_size , n_block_layers = 4 , num_harmonics = 1 ):
264- super ().__init__ ()
284+ def __init__ (
285+ self , train_sequence_length , predict_sequence_length , hidden_size , n_block_layers = 4 , num_harmonics = 1 , ** kwargs
286+ ):
287+ super ().__init__ (** kwargs )
265288 self .train_sequence_length = train_sequence_length
266289 self .predict_sequence_length = predict_sequence_length
267290 self .hidden_size = hidden_size
@@ -300,6 +323,7 @@ def __init__(self, train_sequence_length, predict_sequence_length, hidden_size,
300323 self .forecast_sin_template = tf .transpose (tf .sin (self .forecast_grid ))
301324
302325 def build (self , input_shape : Tuple [Optional [int ], ...]):
326+ super ().build (input_shape )
303327 self .layers = [Dense (self .hidden_size , activation = "relu" ) for _ in range (self .n_block_layers )]
304328 self .theta = Dense (self .theta_size , use_bias = False , activation = None )
305329
@@ -336,17 +360,21 @@ def call(self, inputs):
336360 self .forecast_sin_template ,
337361 )
338362
339-
340- class ZerosLayer (tf .keras .layers .Layer ):
341- """Layer for creating zeros tensor with proper shape"""
342-
343- def __init__ (self , predict_length , ** kwargs ):
344- super (ZerosLayer , self ).__init__ (** kwargs )
345- self .predict_length = predict_length
346-
347- def call (self , x ):
348- batch_size = tf .shape (x )[0 ]
349- return tf .zeros ([batch_size , self .predict_length ], dtype = tf .float32 )
350-
351363 def compute_output_shape (self , input_shape ):
352- return (input_shape [0 ], self .predict_length )
364+ batch_size = input_shape [0 ]
365+ backcast_shape = (batch_size , self .train_sequence_length )
366+ forecast_shape = (batch_size , self .predict_sequence_length )
367+ return (backcast_shape , forecast_shape )
368+
369+ def get_config (self ):
370+ config = super ().get_config ()
371+ config .update (
372+ {
373+ "train_sequence_length" : self .train_sequence_length ,
374+ "predict_sequence_length" : self .predict_sequence_length ,
375+ "hidden_size" : self .hidden_size ,
376+ "n_block_layers" : self .n_block_layers ,
377+ "num_harmonics" : self .num_harmonics ,
378+ }
379+ )
380+ return config
0 commit comments