Description
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04.4 LTS
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.18.0
- Python version: 3.11.11
- Bazel version (if compiling from source): N/A
- GPU model and memory: RTX Ada A6000, 48 GB VRAM
Describe the problem.
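The `SpectralNormalization` wrapper in `tf_keras` does not work. Wrapping a `Conv2D` layer with it fails as soon as the wrapped layer is called on an input tensor.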
Describe the current behavior.
When used, it produces the following error:
```
Traceback (most recent call last):
  File "/home/dryglicki/code/testing/test_specnorm.py", line 15, in <module>
    x = KL.SpectralNormalization(KL.Conv2D(32, 3, padding='same'))(inputs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d18_py3d11/lib/python3.11/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d18_py3d11/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1288, in input_spec
    raise TypeError(
TypeError: Layer input_spec must be an instance of InputSpec. Got: InputSpec(shape=(None, None, None, 12), ndim=4)
```
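At first glance the error contradicts itself: the rejected object prints as an `InputSpec`. One plausible reading (an assumption, not confirmed against the `tf_keras` source) is that the wrapper builds its spec from a different package's `InputSpec` class than the one `tf_keras`'s `base_layer.py` checks with `isinstance`. For example, this could happen if it goes through `tf.keras`, which resolves to Keras 3 under TensorFlow 2.18. The pure-Python sketch below uses hypothetical stand-in classes, not the real Keras ones, to show that two distinct classes sharing the name `InputSpec` reproduce exactly this message.

```python
def make_input_spec_class():
    """Return a fresh class named InputSpec; each call creates a distinct type."""

    class InputSpec:
        def __init__(self, shape=None, ndim=None):
            self.shape = shape
            self.ndim = ndim

        def __repr__(self):
            return f"InputSpec(shape={self.shape}, ndim={self.ndim})"

    return InputSpec


# Hypothetical stand-ins: the class the layer checks against, and a
# same-named class that (in this hypothesis) comes from another package.
ExpectedSpec = make_input_spec_class()
OtherSpec = make_input_spec_class()


def set_input_spec(value):
    """Mimic an input_spec setter that only accepts ExpectedSpec instances."""
    if not isinstance(value, ExpectedSpec):
        raise TypeError(
            f"Layer input_spec must be an instance of InputSpec. Got: {value}"
        )
    return value


message = ""
spec = OtherSpec(shape=(None, None, None, 12), ndim=4)
try:
    set_input_spec(spec)
except TypeError as err:
    message = str(err)
print(message)
```

If this is the cause, one untested thing to try is setting the environment variable `TF_USE_LEGACY_KERAS=1` before `tensorflow` is imported, so that `tf.keras` and `tf_keras` refer to the same package.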
Describe the expected behavior.
The `SpectralNormalization`-wrapped `Conv2D` should build and train the same way the unwrapped `Conv2D` does.
- Do you want to contribute a PR? (yes/no): no
Standalone code to reproduce the issue.
```python
import tensorflow as tf
import tf_keras as K
import tf_keras.layers as KL

nb = 15  # batch size
nt = 6   # unused in this snippet
nx = 32
ny = 32
nc = 12  # number of channels

# Random inputs and targets with shape (batch, x, y, channels)
xx = tf.random.normal((nb, nx, ny, nc))
yy = tf.random.normal((nb, nx, ny, nc))

inputs = KL.Input( [None, None, nc], name = 'dummy_input')
# Raises the TypeError shown above:
x = KL.SpectralNormalization(KL.Conv2D(32, 3, padding='same'))(inputs)
# Works when used instead of the line above:
# x = KL.Conv2D(32, 3, padding='same')(inputs)
x = KL.Conv2D(nc, 1, padding='same')(x)
model = K.models.Model(inputs = inputs, outputs = x)
model.compile(loss = 'mae')
model.fit(xx,yy,epochs = 10)
```
Swapping the two `Conv2D(32, ...)` lines (commenting out the `SpectralNormalization` line and uncommenting the plain `Conv2D` line) results in working code.
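For context on what the wrapper is meant to compute: spectral normalization divides a layer's kernel by an estimate of its largest singular value. That estimate comes from power iteration on the kernel reshaped to a 2-D matrix. The NumPy sketch below illustrates the computation, assuming the kernel is reshaped to `(-1, out_channels)`. It is a simplified illustration, not the `tf_keras` implementation, and `spectral_normalize` is a hypothetical name. As I understand it, the real wrapper runs only `power_iterations` steps (default 1) per training step and keeps its vector between steps. This sketch instead runs many steps at once to get a standalone estimate.

```python
import numpy as np


def spectral_normalize(kernel, power_iterations=200, seed=0):
    """Return (kernel / sigma, sigma), where sigma approximates the largest
    singular value of the kernel reshaped to (-1, out_channels)."""
    w = kernel.reshape(-1, kernel.shape[-1])  # (kh*kw*c_in, c_out) for Conv2D
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(1, w.shape[-1]))     # right vector, shape (1, c_out)
    for _ in range(power_iterations):
        v = u @ w.T                           # left vector, shape (1, kh*kw*c_in)
        v /= np.linalg.norm(v)
        u = v @ w                             # back to shape (1, c_out)
        u /= np.linalg.norm(u)
    sigma = (v @ w @ u.T).item()              # estimate of the top singular value
    return kernel / sigma, sigma


# A Conv2D(32, 3) kernel applied to 12 input channels has shape (3, 3, 12, 32).
kernel = np.random.default_rng(1).normal(size=(3, 3, 12, 32))
normalized, sigma = spectral_normalize(kernel)
exact = np.linalg.svd(kernel.reshape(-1, 32), compute_uv=False)[0]
print(sigma, exact)
```

After normalization, the largest singular value of the reshaped kernel should be close to 1.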