Skip to content

Commit 5d50872

Browse files
authored
fix(layer): standard latest keras layer
1 parent bad1cfa commit 5d50872

File tree

21 files changed

+338
-126
lines changed

21 files changed

+338
-126
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
poetry run python -m pip install pip -U
5959
poetry install --no-interaction --no-root
6060
poetry run python -m pip install tensorflow==${{ matrix.tf-version }}
61-
poetry run python -m pip install matplotlib
61+
poetry run python -m pip install matplotlib numpy==1.26.0
6262
6363
- name: Run unittest
6464
shell: bash

docker/Dockerfile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
From tensorflow/tensorflow:2.16.1-gpu
22

33
RUN apt-get update
4-
RUN apt-get install -y libgl1-mesa-dev wget vim python3.9
4+
RUN apt-get install -y libgl1-mesa-dev wget vim
55

66
RUN pip install --no-cache-dir tfts
77

8+
EXPOSE 8888
9+
810
# Set the default command to python3.
911
CMD ["python3"]

docs/source/quick-start.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,11 @@ Run with pretrained weights
124124
model = AutoModel.from_pretrained("tfts-model")
125125
126126
127+
3.3 Save and load the model
128+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
127129

128-
3.3 Serve the model
130+
131+
3.4 Serve the model
129132
~~~~~~~~~~~~~~~~~~~~~~~
130133
Once the model is trained and evaluated, deploy it for inference. Ensure the model is saved in a format compatible with your serving environment (e.g., TensorFlow SavedModel, ONNX, etc.). Set up an API or service to handle incoming requests, preprocess input data, and return predictions in real-time.
131134

tests/test_examples/test_tfts_inputs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_encoder_array(self):
2323
y_valid = np.random.rand(1, predict_sequence_length, 1)
2424

2525
for m in self.test_models:
26-
logger.info(f"Test model {m}")
26+
print(f"==== Test model {m} ====")
2727
config = AutoConfig.for_model(m)
2828
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)
2929
trainer = KerasTrainer(model)
@@ -65,6 +65,7 @@ def test_encoder_decoder_array2(self):
6565
n_decoder_feature = 3
6666

6767
x_train = (
68+
# x, encoder, decoder
6869
np.random.rand(1, train_length, 1),
6970
np.random.rand(1, train_length, n_encoder_feature),
7071
np.random.rand(1, predict_sequence_length, n_decoder_feature),
@@ -78,6 +79,7 @@ def test_encoder_decoder_array2(self):
7879
y_valid = np.random.rand(1, predict_sequence_length, 1)
7980

8081
for m in self.test_models:
82+
print(f"==== Test model {m} ====")
8183
config = AutoConfig.for_model(m)
8284
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)
8385
trainer = KerasTrainer(model)
@@ -116,6 +118,7 @@ def test_encoder_decoder_tfdata(self):
116118
valid_loader = valid_loader.batch(batch_size=1)
117119

118120
for m in self.test_models:
121+
print(f"==== Test model {m} ====")
119122
config = AutoConfig.for_model(m)
120123
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)
121124
trainer = KerasTrainer(model)

tests/test_models/test_wavenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_encoder(self):
2626
def test_decoder1(self):
2727
filters = 32
2828
dilation_rates = [2]
29-
dense_hidden_size = 32
29+
dense_hidden_size = 64
3030
predict_sequence_length = 3
3131
layer = DecoderV1(filters, dilation_rates, dense_hidden_size, predict_sequence_length)
3232

tfts/layers/attention_layer.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def call(
5252
training: Optional[bool] = None,
5353
return_attention_scores: bool = False,
5454
use_causal_mask: bool = False,
55+
**kwargs,
5556
):
5657
"""use query and key generating an attention multiplier for value, multi_heads to repeat it
5758
@@ -110,12 +111,31 @@ def get_config(self):
110111
return dict(list(base_config.items()) + list(config.items()))
111112

112113
def compute_output_shape(self, input_shape):
113-
if isinstance(input_shape, (list, tuple)) and len(input_shape) == 3:
114-
q_shape = input_shape[0]
115-
else:
116-
raise ValueError("Expected input_shape to be a list or tuple of three elements (q, k, v)")
114+
if isinstance(input_shape, tuple) and len(input_shape) == 3:
115+
batch_size, seq_len, _ = input_shape
116+
return (batch_size, seq_len, self.hidden_size)
117+
118+
elif isinstance(input_shape, (list, tuple)) and len(input_shape) == 3:
119+
q_shape, k_shape, v_shape = input_shape
117120

118-
return (q_shape[0], q_shape[1], self.hidden_size)
121+
# Validate that all shapes are tuples with 3 dimensions
122+
if not all(isinstance(shape, tuple) and len(shape) == 3 for shape in [q_shape, k_shape, v_shape]):
123+
raise ValueError(
124+
"Each input shape must be a tuple of length 3 (batch_size, seq_len, features). "
125+
f"Got shapes: q={q_shape}, k={k_shape}, v={v_shape}"
126+
)
127+
128+
# Output shape is based on query sequence length
129+
batch_size, seq_q_len, _ = q_shape
130+
return (batch_size, seq_q_len, self.hidden_size)
131+
132+
else:
133+
raise ValueError(
134+
"Expected input_shape to be either:\n"
135+
"1. A single tuple (batch_size, seq_len, features) for self-attention, or\n"
136+
"2. A list/tuple of 3 shapes [(q_shape), (k_shape), (v_shape)] for cross-attention.\n"
137+
f"Got: {input_shape}"
138+
)
119139

120140

121141
class SelfAttention(tf.keras.layers.Layer):
@@ -161,9 +181,6 @@ def get_config(self):
161181
return base_config
162182

163183
def compute_output_shape(self, input_shape):
164-
"""
165-
Compute the output shape of the self-attention layer.
166-
"""
167184
return (input_shape[0], input_shape[1], self.hidden_size)
168185

169186

tfts/layers/cnn_layer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def build(self, input_shape: Tuple[int]) -> None:
6969
input_shape : Tuple[int]
7070
Shape of the input tensor
7171
"""
72+
super(ConvTemp, self).build(input_shape)
7273
self.conv = tf.keras.layers.Conv1D(
7374
kernel_size=self.kernel_size,
7475
kernel_initializer=initializers.get(self.kernel_initializer),
@@ -77,7 +78,8 @@ def build(self, input_shape: Tuple[int]) -> None:
7778
dilation_rate=self.dilation_rate,
7879
activation=activations.get(self.activation),
7980
)
80-
super(ConvTemp, self).build(input_shape)
81+
self.conv.build(input_shape)
82+
self.built = True
8183

8284
def call(self, inputs):
8385
"""Forward pass of the layer.

tfts/layers/dense_layer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,13 @@ def get_config(self):
7676
base_config = super(DenseTemp, self).get_config()
7777
return dict(list(base_config.items()) + list(config.items()))
7878

79+
def compute_output_shape(self, input_shape):
80+
return tf.TensorShape(input_shape[:-1] + (self.hidden_size,))
81+
7982

8083
class FeedForwardNetwork(tf.keras.layers.Layer):
81-
def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
82-
super(FeedForwardNetwork, self).__init__()
84+
def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0, **kwargs):
85+
super(FeedForwardNetwork, self).__init__(**kwargs)
8386
self.hidden_size = hidden_size
8487
self.intermediate_size = intermediate_size
8588
self.hidden_dropout_prob = hidden_dropout_prob

tfts/layers/embed_layer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import tensorflow as tf
66
from tensorflow.keras.layers import GRU, Embedding
7+
from tensorflow.keras.utils import register_keras_serializable
78

89
from .position_layer import PositionalEmbedding, PositionalEncoding, RelativePositionEmbedding
910

tfts/layers/nbeats_layer.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def seasonality_model(
100100
"bp,pt->bt", theta[:, config_per_harmonic : 2 * config_per_harmonic], forecast_sin_template
101101
)
102102
forecast = forecast_harmonics_sin + forecast_harmonics_cos
103-
104103
return backcast, forecast
105104

106105

@@ -123,9 +122,14 @@ class GenericBlock(tf.keras.layers.Layer):
123122
"""
124123

125124
def __init__(
126-
self, train_sequence_length: int, predict_sequence_length: int, hidden_size: int, n_block_layers: int = 4
125+
self,
126+
train_sequence_length: int,
127+
predict_sequence_length: int,
128+
hidden_size: int,
129+
n_block_layers: int = 4,
130+
**kwargs
127131
):
128-
super(GenericBlock, self).__init__()
132+
super(GenericBlock, self).__init__(**kwargs)
129133
self.train_sequence_length = train_sequence_length
130134
self.predict_sequence_length = predict_sequence_length
131135
self.hidden_size = hidden_size
@@ -139,9 +143,9 @@ def build(self, input_shape: Tuple[Optional[int], ...]):
139143
input_shape : Tuple[Optional[int], ...]
140144
Shape of the input tensor
141145
"""
146+
super(GenericBlock, self).build(input_shape)
142147
self.layers = [Dense(self.hidden_size, activation="relu") for _ in range(self.n_block_layers)]
143148
self.theta = Dense(self.train_sequence_length + self.predict_sequence_length, use_bias=False, activation=None)
144-
super(GenericBlock, self).build(input_shape)
145149

146150
def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
147151
"""Compute the output of the Generic Block.
@@ -164,6 +168,24 @@ def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
164168
x = self.theta(x)
165169
return generic_model(x, tf.range(self.train_sequence_length), tf.range(self.predict_sequence_length))
166170

171+
def compute_output_shape(self, input_shape):
172+
batch_size = input_shape[0]
173+
backcast_shape = (batch_size, self.train_sequence_length)
174+
forecast_shape = (batch_size, self.predict_sequence_length)
175+
return (backcast_shape, forecast_shape)
176+
177+
def get_config(self):
178+
config = super().get_config()
179+
config.update(
180+
{
181+
"train_sequence_length": self.train_sequence_length,
182+
"predict_sequence_length": self.predict_sequence_length,
183+
"hidden_size": self.hidden_size,
184+
"n_block_layers": self.n_block_layers,
185+
}
186+
)
187+
return config
188+
167189

168190
class TrendBlock(tf.keras.layers.Layer):
169191
"""Trend block that learns trend patterns using polynomial basis functions.
@@ -192,8 +214,9 @@ def __init__(
192214
hidden_size: int,
193215
n_block_layers: int = 4,
194216
polynomial_term: int = 2,
217+
**kwargs
195218
):
196-
super().__init__()
219+
super().__init__(**kwargs)
197220

198221
self.train_sequence_length = train_sequence_length
199222
self.predict_sequence_length = predict_sequence_length
@@ -226,12 +249,10 @@ def build(self, input_shape: Tuple[Optional[int], ...]):
226249
input_shape : Tuple[Optional[int], ...]
227250
Shape of the input tensor
228251
"""
229-
252+
super().build(input_shape)
230253
self.layers = [Dense(self.hidden_size, activation="relu") for _ in range(self.n_block_layers)]
231254
self.theta = Dense(2 * self.polynomial_size, use_bias=False, activation=None)
232255

233-
super().build(input_shape)
234-
235256
def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
236257
"""Compute the output of the Trend Block.
237258
@@ -254,14 +275,16 @@ def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
254275
return trend_model(x, self.backcast_time, self.forecast_time, self.polynomial_size)
255276

256277
def compute_output_shape(self, input_shape):
257-
return [(input_shape[0], self.train_sequence_length), (input_shape[0], self.predict_sequence_length)]
278+
return ((input_shape[0], self.train_sequence_length), (input_shape[0], self.predict_sequence_length))
258279

259280

260281
class SeasonalityBlock(tf.keras.layers.Layer):
261282
"""Seasonality block"""
262283

263-
def __init__(self, train_sequence_length, predict_sequence_length, hidden_size, n_block_layers=4, num_harmonics=1):
264-
super().__init__()
284+
def __init__(
285+
self, train_sequence_length, predict_sequence_length, hidden_size, n_block_layers=4, num_harmonics=1, **kwargs
286+
):
287+
super().__init__(**kwargs)
265288
self.train_sequence_length = train_sequence_length
266289
self.predict_sequence_length = predict_sequence_length
267290
self.hidden_size = hidden_size
@@ -300,6 +323,7 @@ def __init__(self, train_sequence_length, predict_sequence_length, hidden_size,
300323
self.forecast_sin_template = tf.transpose(tf.sin(self.forecast_grid))
301324

302325
def build(self, input_shape: Tuple[Optional[int], ...]):
326+
super().build(input_shape)
303327
self.layers = [Dense(self.hidden_size, activation="relu") for _ in range(self.n_block_layers)]
304328
self.theta = Dense(self.theta_size, use_bias=False, activation=None)
305329

@@ -336,17 +360,21 @@ def call(self, inputs):
336360
self.forecast_sin_template,
337361
)
338362

339-
340-
class ZerosLayer(tf.keras.layers.Layer):
341-
"""Layer for creating zeros tensor with proper shape"""
342-
343-
def __init__(self, predict_length, **kwargs):
344-
super(ZerosLayer, self).__init__(**kwargs)
345-
self.predict_length = predict_length
346-
347-
def call(self, x):
348-
batch_size = tf.shape(x)[0]
349-
return tf.zeros([batch_size, self.predict_length], dtype=tf.float32)
350-
351363
def compute_output_shape(self, input_shape):
352-
return (input_shape[0], self.predict_length)
364+
batch_size = input_shape[0]
365+
backcast_shape = (batch_size, self.train_sequence_length)
366+
forecast_shape = (batch_size, self.predict_sequence_length)
367+
return (backcast_shape, forecast_shape)
368+
369+
def get_config(self):
370+
config = super().get_config()
371+
config.update(
372+
{
373+
"train_sequence_length": self.train_sequence_length,
374+
"predict_sequence_length": self.predict_sequence_length,
375+
"hidden_size": self.hidden_size,
376+
"n_block_layers": self.n_block_layers,
377+
"num_harmonics": self.num_harmonics,
378+
}
379+
)
380+
return config

0 commit comments

Comments
 (0)