Skip to content

Commit d55940d

Browse files
committed
Fix ops.tile shape inference issue on TensorFlow backend
When ops.tile is called inside a Layer's call method with concrete integer repeats, the TensorFlow backend converted those repeats to a tensor, which prevented TensorFlow's shape inference from determining the output shape. As a result, every dimension of the output shape was reported as None (for example, (None, None, None, None) instead of (None, 6, 2, 2)).

Changes:

1. Modified the TensorFlow backend's tile() to detect when repeats contains only concrete integer values and to pass them directly to tf.tile as a Python list/tuple instead of converting them to a tensor. This allows TensorFlow's shape inference to work correctly.
2. Enhanced ops.numpy.Tile.compute_output_spec() to handle symbolic repeat values more gracefully by checking whether each repeat is a concrete integer before attempting the multiplication.
3. Added regression tests to verify that shape inference works correctly both in direct ops.tile calls and when ops.tile is used inside Layer.call().

Fixes #20914

Signed-off-by: Samaresh Kumar Singh <[email protected]>
1 parent e048ae4 commit d55940d

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

keras/src/backend/tensorflow/numpy.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,6 +2742,42 @@ def round(x, decimals=0):
27422742

27432743
def tile(x, repeats):
27442744
x = convert_to_tensor(x)
2745+
2746+
# Check if repeats contains only concrete integers
2747+
# If so, keep it as a Python list/tuple for better shape inference
2748+
try:
2749+
if isinstance(repeats, (list, tuple)):
2750+
# Try to extract concrete integer values
2751+
concrete_repeats = []
2752+
for r in repeats:
2753+
if isinstance(r, int):
2754+
concrete_repeats.append(r)
2755+
elif hasattr(r, 'numpy') and r.shape == ():
2756+
# Scalar tensor with concrete value
2757+
concrete_repeats.append(int(r.numpy()))
2758+
else:
2759+
# Not a concrete value, fall back to tensor path
2760+
concrete_repeats = None
2761+
break
2762+
2763+
if concrete_repeats is not None:
2764+
# Use concrete repeats directly for better shape inference
2765+
repeats = concrete_repeats
2766+
# Pad or trim repeats to match x rank
2767+
x_rank = x.shape.rank
2768+
if x_rank is not None:
2769+
if len(repeats) < x_rank:
2770+
repeats = [1] * (x_rank - len(repeats)) + repeats
2771+
elif len(repeats) > x_rank:
2772+
# Need to reshape x to match repeats length
2773+
x_shape_list = [1] * (len(repeats) - x_rank) + [d if d is not None else -1 for d in x.shape.as_list()]
2774+
x = tf.reshape(x, x_shape_list)
2775+
return tf.tile(x, repeats)
2776+
except Exception:
2777+
# If anything goes wrong, fall back to original implementation
2778+
pass
2779+
2780+
# Original dynamic implementation for non-concrete repeats
27452781
repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
27462782
repeats_size = tf.size(repeats)
27472783
repeats = tf.pad(

keras/src/ops/numpy.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6411,17 +6411,31 @@ def compute_output_spec(self, x):
64116411
repeats = self.repeats
64126412
if isinstance(repeats, int):
64136413
repeats = [repeats]
6414+
6415+
# Convert repeats to list if it's a tuple or other iterable
6416+
# and extract concrete integer values
6417+
if not isinstance(repeats, list):
6418+
try:
6419+
repeats = list(repeats)
6420+
except TypeError:
6421+
repeats = [repeats]
6422+
64146423
if len(x_shape) > len(repeats):
64156424
repeats = [1] * (len(x_shape) - len(repeats)) + repeats
64166425
else:
64176426
x_shape = [1] * (len(repeats) - len(x_shape)) + x_shape
64186427

64196428
output_shape = []
64206429
for x_size, repeat in zip(x_shape, repeats):
6430+
# Check if repeat is a concrete integer value
6431+
# If it's a symbolic tensor or unknown, we can't infer the size
64216432
if x_size is None:
64226433
output_shape.append(None)
6423-
else:
6434+
elif isinstance(repeat, int):
64246435
output_shape.append(x_size * repeat)
6436+
else:
6437+
# repeat is symbolic (e.g., KerasTensor, tf.Tensor, etc.)
6438+
output_shape.append(None)
64256439
return KerasTensor(output_shape, dtype=x.dtype)
64266440

64276441

keras/src/ops/numpy_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,10 @@ def test_tile(self):
18201820
self.assertEqual(knp.tile(x, [2]).shape, (None, 6))
18211821
self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6))
18221822
self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6))
1823+
1824+
# Test with multi-dimensional input
1825+
x = KerasTensor((None, 3, 2, 2))
1826+
self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2))
18231827

18241828
def test_trace(self):
18251829
x = KerasTensor((None, 3, None, 5))
@@ -9507,3 +9511,23 @@ def call(self, x):
95079511
model.compile(jit_compile=jit_compile)
95089512

95099513
model.predict(np.random.randn(1, 8))
9514+
9515+
def test_tile_shape_inference_in_layer(self):
    """Check the static output shape of `ops.tile` used inside a Layer.

    Regression test for issue #20914: on the TensorFlow backend, calling
    tile with concrete integer repeats from a Layer's `call` method
    produced an output whose dimensions were all `None`.
    """

    class TileLayer(keras.layers.Layer):
        def call(self, x):
            # Plain Python ints, so the repeats are known statically.
            return knp.tile(x, [1, 2, 1, 1])

    symbolic_input = keras.Input(shape=(3, 2, 2))
    tiled = TileLayer()(symbolic_input)

    # Only axis 1 is repeated (3 * 2 = 6); the batch axis stays unknown.
    # Before the fix the shape came back as (None, None, None, None).
    self.assertEqual(tiled.shape, (None, 6, 2, 2))

0 commit comments

Comments
 (0)