Skip to content

Commit 9c675a9

Browse files
authored
Bug fix for TF backend using random zoom with flatten (#18835)
* bug fix * skip the numpy backend
1 parent 7a65022 commit 9c675a9

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

keras/backend/tensorflow/image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def affine_transform(
119119
interpolation=interpolation.upper(),
120120
fill_mode=fill_mode.upper(),
121121
)
122+
affined = tf.ensure_shape(affined, image.shape)
122123

123124
if data_format == "channels_first":
124125
affined = tf.transpose(affined, (0, 3, 1, 2))

keras/layers/preprocessing/random_zoom_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
from absl.testing import parameterized
34
from tensorflow import data as tf_data
45

@@ -132,3 +133,19 @@ def test_dynamic_shape(self):
132133
)(inputs)
133134
model = models.Model(inputs, outputs)
134135
model.predict(np.random.random((1, 6, 6, 3)))
136+
137+
@pytest.mark.skipif(
138+
backend.backend() == "numpy",
139+
reason="The NumPy backend does not implement fit.",
140+
)
141+
def test_connect_with_flatten(self):
142+
model = models.Sequential(
143+
[
144+
layers.RandomZoom((-0.5, 0.0), (-0.5, 0.0)),
145+
layers.Flatten(),
146+
layers.Dense(1, activation="relu"),
147+
],
148+
)
149+
150+
model.compile(loss="mse")
151+
model.fit(np.random.random((2, 2, 2, 1)), y=np.random.random((2,)))

0 commit comments

Comments
 (0)