Closed
Description
The shape of `constants` changes from a rank-2 tensor during inference to a rank-3 tensor during model saving
System Info:
- TensorFlow version: 2.15 (with CUDA)
- OS platform and distribution: Windows WSL
- Python version: 3.11
from tensorflow.python.keras.layers.recurrent import (
DropoutRNNCellMixin,
_config_for_enable_caching_device,
_caching_device,
)
import tensorflow as tf
class RNNWithConstants(
    DropoutRNNCellMixin, tf.keras.__internal__.layers.BaseRandomLayer
):
    """RNN cell that concatenates a constant tensor onto every step input.

    At each step the first tensor in ``constants`` is concatenated with the
    step input along the last axis, and the result is fed to an inner
    ``tf.keras.layers.GRUCell``. The GRU output is returned both as the cell
    output and as the new state.

    Args:
        units: Size of the hidden state and of the cell output.
        activation: Activation forwarded to the inner ``GRUCell``.
        recurrent_activation: Recurrent activation forwarded to the inner
            ``GRUCell``.
        dropout: Input dropout rate forwarded to the inner ``GRUCell``.
        recurrent_dropout: Recurrent dropout rate forwarded to the inner
            ``GRUCell``.
        **kwargs: Passed through to the base layer constructor.
    """

    # NOTE(review): DropoutRNNCellMixin is imported from the private, legacy
    # ``tensorflow.python.keras`` package, while the base layer comes from
    # ``tf.keras``. Dropout itself is performed by the inner GRUCell; confirm
    # whether this mixin is needed at all.

    def __init__(
        self,
        units,
        activation,
        recurrent_activation,
        dropout,
        recurrent_dropout,
        **kwargs,
    ):
        super(RNNWithConstants, self).__init__(**kwargs)
        self.units = units
        # Keep every constructor argument on the instance (activation was
        # previously the only one not recorded).
        self.activation = activation
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.recurrent_activation = recurrent_activation
        self.cell = tf.keras.layers.GRUCell(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
        )
        # Sizes read by tf.keras.layers.RNN to create the state and output.
        self.state_size = units
        self.output_size = units

    @tf.function
    def call(self, inputs, states, constants, training=None):
        """Run one recurrent step.

        Args:
            inputs: Step input tensor.
            states: Previous state(s), passed unchanged to the inner GRUCell.
            constants: Sequence of constant tensors; only ``constants[0]``
                is used, concatenated with ``inputs`` on the last axis.
            training: Training-mode flag. tf.keras.layers.RNN passes it only
                when the cell's ``call`` declares it; it is forwarded so the
                inner GRUCell's dropout follows the caller's mode explicitly.

        Returns:
            ``(h, h)``: the GRU output, used as both cell output and new state.
        """
        # Trace-time diagnostics: these run only when tf.function traces.
        print(f"inputs {inputs.shape}")
        print(f"states {states[0].shape}")
        print(f"constants {constants[0].shape}")
        # NOTE(review): tf.concat requires both tensors to have the same rank.
        # The log in this report shows constants[0] traced as (None, None, 10)
        # during SavedModel export, versus (None, 10) for inputs, which makes
        # this concat fail. The root cause is not identified here.
        inputs = tf.concat([inputs, constants[0]], axis=-1)
        h, _ = self.cell(inputs, states, training=training)
        return h, h
class ConstantsModel(tf.keras.models.Model):
    """Model that runs ``RNNWithConstants`` over a sequence with a constant.

    Called as ``model(inputs, constants)``. ``constants`` is handed to the
    ``tf.keras.layers.RNN`` layer via its ``constants=`` argument, and the
    layer supplies it to the cell at every step.

    Args:
        units: Hidden size of the recurrent cell.
        activation: Activation for the cell (default ``"sigmoid"``).
        recurrent_activation: Recurrent activation for the cell (default
            ``"sigmoid"``).
        dropout: Input dropout rate for the cell (default ``0.1``).
        recurrent_dropout: Recurrent dropout rate for the cell (default
            ``0.1``).
        **kwargs: Passed through to ``tf.keras.models.Model``.
    """

    def __init__(
        self,
        units,
        *,
        activation="sigmoid",
        recurrent_activation="sigmoid",
        dropout=0.1,
        recurrent_dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        # The defaults reproduce the previously hard-coded configuration.
        self.cell = RNNWithConstants(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
        )
        self.rnn = tf.keras.layers.RNN(self.cell)

    @tf.function
    def call(self, inputs, constants, training=None):
        """Run the RNN over ``inputs`` with ``constants`` attached.

        Args:
            inputs: Input sequence for the RNN layer.
            constants: Tensor (or list of tensors) forwarded as the RNN
                layer's ``constants``.
            training: Training-mode flag. It was previously accepted but
                dropped; it is now forwarded to the RNN layer.

        Returns:
            The RNN layer's output.
        """
        return self.rnn(inputs, constants=constants, training=training)
const = ConstantsModel(10)

print("initializing...")
# One eager forward pass so the subclassed model (and its RNN and cell) is
# built before it is saved. The sequence tensor is drawn before the constant
# tensor, matching the original argument evaluation order.
sequence_batch = tf.random.normal(shape=(100, 50, 10))
constant_batch = tf.random.normal(shape=(100, 10))
_ = const(sequence_batch, constant_batch)

print("\nsaving....")
# No file extension is given, so the save format is chosen by Keras from the
# path; the log above shows this triggers re-tracing of the cell.
const.save("./const_model")
Relevant Log Output
initializing...
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)
saving....
inputs (None, 10)
states (None, 10)
constants (None, 10)
inputs (None, 10)
states (None, 10)
constants (None, None, 10)