Skip to content

Custom Keras RNN with constants changes constants shape when saving #770

Closed
@claCase

Description

@claCase

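The constants tensor passed to a custom RNN cell changes from rank 2 during inference to rank 3 during model saving, which breaks the `tf.concat` of inputs and constants inside the cell's `call`.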

System Info:

  • TensorFlow version (Cuda): 2.15
  • OS platform and distribution: Windows WSL
  • Python version: 3.11
from tensorflow.python.keras.layers.recurrent import (
    DropoutRNNCellMixin,
    _config_for_enable_caching_device,
    _caching_device,
)
import tensorflow as tf

class RNNWithConstants(
    DropoutRNNCellMixin, tf.keras.__internal__.layers.BaseRandomLayer
):
    """Custom RNN cell that appends a constant tensor to each step input.

    The cell owns a `tf.keras.layers.GRUCell` and, on every call, concatenates
    `constants[0]` onto `inputs` along the last axis before delegating to that
    GRUCell.

    Args:
        units: Forwarded to the inner GRUCell as `units`; also stored as
            `self.units`, `self.state_size` and `self.output_size`.
        activation: Forwarded to the inner GRUCell. It is not stored on this
            object.
        recurrent_activation: Forwarded to the inner GRUCell and stored as
            `self.recurrent_activation`.
        dropout: Forwarded to the inner GRUCell and stored as `self.dropout`.
        recurrent_dropout: Forwarded to the inner GRUCell and stored as
            `self.recurrent_dropout`.
        **kwargs: Passed through to the base-class constructor.
    """

    def __init__(
        self,
        units,
        activation,
        recurrent_activation,
        dropout,
        recurrent_dropout,
        **kwargs,
    ):
        super(RNNWithConstants, self).__init__(**kwargs)
        self.units = units
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.recurrent_activation = recurrent_activation
        # The wrapped cell that does the actual recurrent computation.
        self.cell = tf.keras.layers.GRUCell(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
        )
        # NOTE(review): presumably read by tf.keras.layers.RNN as the cell's
        # state/output sizes -- confirm against the Keras RNN cell contract.
        self.state_size = units
        self.output_size = units

    @tf.function
    def call(self, inputs, states, constants):
        """Run one step: concatenate `constants[0]` to `inputs`, then call the GRUCell.

        Args:
            inputs: Step input tensor.
            states: Sequence of state tensors; passed unchanged to the GRUCell.
            constants: Sequence of constant tensors; only `constants[0]` is used.

        Returns:
            A pair `(h, h)` where `h` is the first element of the GRUCell's
            return value.
        """
        # Debug output of the shapes seen by this call.
        # NOTE(review): because of @tf.function these Python prints presumably
        # run only while the function is being traced -- confirm.
        print(f"inputs {inputs.shape}")
        print(f"states {states[0].shape}")
        print(f"constants {constants[0].shape}")
        # This is the line the report is about: the concat requires inputs and
        # constants[0] to agree in rank.
        inputs = tf.concat([inputs, constants[0]], axis=-1)  # error due to shape change
        # The second element returned by the GRUCell is discarded.
        h, _ = self.cell(inputs, states)
        return h, h
    
class ConstantsModel(tf.keras.models.Model):
    """Model that runs an `RNNWithConstants` cell inside `tf.keras.layers.RNN`.

    Args:
        units: Stored as `self.units` and used as the cell's `units`. The cell
            is created with "sigmoid" for both activations and 0.1 for both
            dropout rates.
        **kwargs: Passed through to the `tf.keras.models.Model` constructor.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units 
        self.cell = RNNWithConstants(units, "sigmoid", "sigmoid", 0.1, 0.1)
        self.rnn = tf.keras.layers.RNN(self.cell)

    @tf.function
    def call(self, inputs, constants, training):
        """Apply the RNN layer to `inputs`, forwarding `constants` by keyword.

        Args:
            inputs: Input passed positionally to the RNN layer.
            constants: Passed to the RNN layer as its `constants` argument.
            training: Accepted but not referenced in this method body.

        Returns:
            Whatever `self.rnn(inputs, constants=constants)` returns.
        """
        return self.rnn(inputs, constants=constants)
    
# Reproduction driver: build the model, call it once, then save it.
const = ConstantsModel(10)
print("initializing...")
# One forward call with inputs of shape (100, 50, 10) and constants of shape
# (100, 10); the result is discarded.
_ = const(tf.random.normal(shape=(100, 50, 10)), tf.random.normal(shape=(100, 10)))
print("\nsaving....")
# Saving is the step at which the report observes constants arriving as rank 3.
const.save("./const_model")

Relevant Log Output

initializing...
inputs (100, 10)
states (100, 10)
constants (100, 10)
inputs (100, 10)
states (100, 10)
constants (100, 10)

saving....
inputs (None, 10)
states (None, 10)
constants (None, 10)
inputs (None, 10)
states (None, 10)
constants (None, None, 10)

Keras Issue

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions