Spectral Normalization layer does not work #816

Open
@dryglicki

Description

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04.4 LTS
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 2.18.0
  • Python version: 3.11.11
  • Bazel version (if compiling from source): N/A
  • GPU model and memory: RTX Ada A6000, 48 GB VRAM

Describe the problem.

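The tf_keras SpectralNormalization wrapper does not work: wrapping a Conv2D layer with it and applying the result to a Keras Input raises a TypeError.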

Describe the current behavior.

Building the model fails on the line that applies the wrapped layer, with the following error:

Traceback (most recent call last):
  File "/home/dryglicki/code/testing/test_specnorm.py", line 15, in <module>
    x = KL.SpectralNormalization(KL.Conv2D(32, 3, padding='same'))(inputs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d18_py3d11/lib/python3.11/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/ssd0/miniforge3_2024-04/envs/tensorflow_2d18_py3d11/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1288, in input_spec
    raise TypeError(
TypeError: Layer input_spec must be an instance of InputSpec. Got: InputSpec(shape=(None, None, None, 12), ndim=4)
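The message is puzzling because the rejected object prints as an InputSpec. Since the layer's input_spec setter rejected it, the object is presumably an instance of a different class that also happens to be named InputSpec (for example, one from a separate Keras package). This is an inference from the message, not something confirmed in this issue. The plain-Python sketch below reproduces that failure mode with two hypothetical modules, pkg_a and pkg_b, and a hypothetical Layer class whose setter mimics a strict isinstance check; none of these names come from tf_keras.

```python
import types


def make_input_spec_module(name):
    """Build a hypothetical module defining its own InputSpec class."""
    module = types.ModuleType(name)

    class InputSpec:
        def __init__(self, shape=None, ndim=None):
            self.shape = shape
            self.ndim = ndim

        def __repr__(self):
            return f"InputSpec(shape={self.shape}, ndim={self.ndim})"

    module.InputSpec = InputSpec
    return module


# Two hypothetical packages whose InputSpec classes print identically
# but are distinct types.
pkg_a = make_input_spec_module("pkg_a")
pkg_b = make_input_spec_module("pkg_b")


class Layer:
    """Hypothetical layer whose input_spec setter mimics a strict type check."""

    @property
    def input_spec(self):
        return self._input_spec

    @input_spec.setter
    def input_spec(self, value):
        # Only the InputSpec class from this layer's own package is accepted.
        if value is not None and not isinstance(value, pkg_a.InputSpec):
            raise TypeError(
                "Layer input_spec must be an instance of InputSpec. "
                f"Got: {value}"
            )
        self._input_spec = value


layer = Layer()
# Accepted: same class the setter checks against.
layer.input_spec = pkg_a.InputSpec(shape=(None, None, None, 12), ndim=4)

message = None
try:
    # Rejected: looks the same when printed, but is a different class.
    layer.input_spec = pkg_b.InputSpec(shape=(None, None, None, 12), ndim=4)
except TypeError as err:
    message = str(err)
    print(message)
```

If this is what is happening, inspecting type(spec).__module__ on the spec the wrapper assigns would show which package its InputSpec comes from.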

Describe the expected behavior.

SpectralNormalization should wrap the Conv2D layer without error, so the model builds and trains just as it does with a plain Conv2D.
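For reference, a working wrapper should divide the wrapped layer's kernel W by an estimate of its largest singular value σ(W), usually obtained with a few steps of power iteration. The plain-Python sketch below illustrates that computation on small 2-D matrices. It is not the tf_keras implementation, and the function names and iteration count are assumptions. For a Conv2D kernel, a common approach is to first reshape the 4-D kernel into a 2-D matrix whose columns correspond to output channels.

```python
import math
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


def normalize(vector, eps=1e-12):
    """Scale a vector to unit length; eps guards against division by zero."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / (norm + eps) for x in vector]


def spectral_norm(matrix, iterations=50, seed=0):
    """Estimate the largest singular value of `matrix` by power iteration."""
    rng = random.Random(seed)
    w_t = transpose(matrix)
    # Random unit vector in the output space of `matrix`.
    u = normalize([rng.gauss(0.0, 1.0) for _ in range(len(matrix))])
    v = normalize(matvec(w_t, u))
    for _ in range(iterations):
        u = normalize(matvec(matrix, v))  # u <- W v / ||W v||
        v = normalize(matvec(w_t, u))     # v <- W^T u / ||W^T u||
    # sigma ~= u^T W v
    return sum(a * b for a, b in zip(u, matvec(matrix, v)))


def spectrally_normalize(matrix, iterations=50):
    """Return matrix / sigma(matrix), whose largest singular value is ~1."""
    sigma = spectral_norm(matrix, iterations)
    return [[x / sigma for x in row] for row in matrix]


kernel = [[3.0, 0.0], [0.0, 1.0]]
print(round(spectral_norm(kernel), 6))
print(round(spectral_norm(spectrally_normalize(kernel)), 6))
```

For diag(3, 1) the estimate is σ ≈ 3, and the normalized matrix has σ ≈ 1, which is the property the wrapper is meant to enforce on the wrapped kernel.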

Contributing.

  • Do you want to contribute a PR? (yes/no): no

Standalone code to reproduce the issue.

import tensorflow as tf
import tf_keras as K
import tf_keras.layers as KL

nb = 15
nt = 6
nx = 32
ny = 32
nc = 12

xx = tf.random.normal((nb, nx, ny, nc))
yy = tf.random.normal((nb, nx, ny, nc))

inputs = KL.Input([None, None, nc], name='dummy_input')
x = KL.SpectralNormalization(KL.Conv2D(32, 3, padding='same'))(inputs)
# x = KL.Conv2D(32, 3, padding='same')(inputs)  # plain Conv2D: this line works
x = KL.Conv2D(nc, 1, padding='same')(x)

model = K.models.Model(inputs=inputs, outputs=x)

model.compile(loss='mae')

model.fit(xx, yy, epochs=10)

Commenting out the SpectralNormalization line and uncommenting the plain Conv2D(32, 3) line below it produces working code.
