22 changes: 22 additions & 0 deletions keras/src/layers/preprocessing/normalization.py
@@ -6,7 +6,9 @@
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.data_layer import DataLayer
from keras.src.trainers.data_adapters import get_data_adapter
from keras.src.utils.module_utils import tensorflow as tf
from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset


@keras_export("keras.layers.Normalization")
@@ -229,6 +231,26 @@ def adapt(self, data):
            # Batch dataset if it isn't batched
            data = data.batch(128)
            input_shape = tuple(data.element_spec.shape)
        elif isinstance(data, PyDataset):
            # A PyDataset may yield bare `x` batches or (x, y) /
            # (x, y, sample_weight) tuples; keep only `x`.
            adapter = get_data_adapter(data)
            tf_dataset = adapter.get_tf_dataset()
            if isinstance(tf_dataset.element_spec, tuple):
                # (x, y) or (x, y, sample_weight) tuples.
                data = tf_dataset.map(lambda *batch: batch[0])
            else:
                # Just x.
                data = tf_dataset
            input_shape = tuple(data.element_spec.shape)
Comment on lines +237 to +247

@limzikiki commented on Nov 3, 2025:
Following up on your comment, what I did in my solution is:

        elif isinstance(data, keras.utils.PyDataset):
            # A PyDataset should return a tuple whose first element is the data.
            sample_input = data[0][0]
            if isinstance(sample_input, np.ndarray) or backend.is_tensor(sample_input):
                input_shape = sample_input.shape
            else:
                raise ValueError(
                    f"Unsupported data type: {type(sample_input)} "
                    "returned from the PyDataset"
                )

The advantage of my option is that we don't need to perform excessive transformations to TF tensors just for the sake of shape estimation. PyDataset is also used for experimentation, and when a dataset is too large to fit in RAM, which is common on workstations and personal devices, transforming the PyDataset into a TF tensor will fail due to memory allocation. The drawback of my solution, on the other hand, is that it retrieves the first batch, and fetching the first batch a second time might not return the same output (that would require a non-idempotent PyDataset, but I think that is then a user problem). Retrieving the first batch is nevertheless a feasible solution, because the element shape must be identical across all batches for normalization to work correctly.

Considering the strategic direction of Keras to move away from being solely dependent on TensorFlow, adding a transformation to TensorFlow creates technical debt that the Keras team will have to take care of later.

I am open to discussion.
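
For reference, a self-contained sketch of this first-batch approach, assuming `backend.is_tensor` from `keras.src.backend` (as used in the snippet above) and extending it to also handle a PyDataset that returns a bare `x`; the helper name `infer_pydataset_input_shape` is hypothetical:

    import numpy as np

    from keras.src import backend

    def infer_pydataset_input_shape(data):
        # One batch suffices: every batch of a PyDataset must share the
        # same per-element shape for normalization to work correctly.
        batch = data[0]
        # Unwrap (x,), (x, y), or (x, y, sample_weight) down to x.
        sample_input = batch[0] if isinstance(batch, tuple) else batch
        if isinstance(sample_input, np.ndarray) or backend.is_tensor(
            sample_input
        ):
            return tuple(sample_input.shape)
        raise ValueError(
            f"Unsupported data type: {type(sample_input)} "
            "returned from the PyDataset"
        )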

Contributor (Author) replied:
Thank you for the thorough response! Yes, I will defer to the core developers' judgement for this. Happy to revise and infer the shape based on sampling a batch if we think that's the better approach.
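
A minimal sketch of what that revision might look like inside `adapt` (hypothetical, not the merged code; it reuses the diff's `PyDataset` and `get_data_adapter` imports and assumes the first tuple element is `x`):

        elif isinstance(data, PyDataset):
            # Infer the input shape from the first batch rather than
            # from tf.data element specs.
            batch = data[0]
            sample_input = batch[0] if isinstance(batch, tuple) else batch
            input_shape = tuple(sample_input.shape)
            # The statistics pass still iterates a tf.data.Dataset,
            # keeping only `x` from (x, y) or (x, y, sample_weight).
            tf_dataset = get_data_adapter(data).get_tf_dataset()
            if isinstance(tf_dataset.element_spec, tuple):
                data = tf_dataset.map(lambda *batch: batch[0])
            else:
                data = tf_dataset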

        else:
            raise TypeError(
                f"Unsupported data type: {type(data)}. `adapt` supports "
                "`np.ndarray`, backend tensors, `tf.data.Dataset`, and "
                "`keras.utils.PyDataset`."
            )

        if not self.built:
            self.build(input_shape)
32 changes: 32 additions & 0 deletions keras/src/layers/preprocessing/normalization_test.py
@@ -169,3 +169,35 @@ def test_normalization_with_scalar_mean_var(self):
        input_data = np.array([[1, 2, 3]], dtype="float32")
        layer = layers.Normalization(mean=3.0, variance=2.0)
        layer(input_data)

    @parameterized.parameters([("x",), ("x_and_y",), ("x_y_and_weights",)])
    def test_adapt_pydataset_compat(self, pydataset_type):
        import keras

        class CustomDataset(keras.utils.PyDataset):
            def __len__(self):
                return 100

            def __getitem__(self, idx):
                x = np.random.rand(32, 32, 3)
                y = np.random.randint(0, 10, size=(1,))
                weights = np.random.randint(0, 10, size=(1,))
                if pydataset_type == "x":
                    return x
                elif pydataset_type == "x_and_y":
                    return x, y
                elif pydataset_type == "x_y_and_weights":
                    return x, y, weights
                else:
                    raise NotImplementedError(pydataset_type)

        normalizer = keras.layers.Normalization()
        normalizer.adapt(CustomDataset())
        self.assertTrue(normalizer.built)
        self.assertIsNotNone(normalizer.mean)
        self.assertIsNotNone(normalizer.variance)
        self.assertEqual(normalizer.mean.shape[-1], 3)
        self.assertEqual(normalizer.variance.shape[-1], 3)
        sample_input = np.random.rand(1, 32, 32, 3)
        output = normalizer(sample_input)
        self.assertEqual(output.shape, (1, 32, 32, 3))