
WeightNormalization layer throws error with TPU Strategy #1703

Open

Description

@sourcecode369

Environment : Google Colab
Hardware Accelerator: TPU
TensorFlow version: 2.2.0-rc3
TensorFlow Addons: 0.9.1

Describe the bug

The Weight Normalization layer (tfa.layers.WeightNormalization) throws an error when run under the TPU distribution strategy.

Code to reproduce the issue

import os

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

print(f"tf.__version__: {tf.__version__}")
tf.config.optimizer.set_jit(True)

try:
    # TPU detection via the Colab runtime's environment variable.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except (KeyError, ValueError):
    raise RuntimeError('ERROR: Not connected to a TPU runtime.')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", tpu_strategy.num_replicas_in_sync)

def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(10000).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
  test_dataset = mnist_test.map(scale).cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

  return train_dataset, test_dataset

def create_model():
  return tf.keras.Sequential(
      [tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(32, 3, activation="elu", input_shape=(28, 28, 1))),
       tf.keras.layers.Flatten(),
       tfa.layers.WeightNormalization(tf.keras.layers.Dense(128, "elu")),
       tfa.layers.WeightNormalization(tf.keras.layers.Dense(10))])
  
train_dataset, test_dataset = get_dataset()

with tpu_strategy.scope():
  model = create_model()
  model.compile(optimizer=tfa.optimizers.Lookahead(tfa.optimizers.RectifiedAdam(
                          lr=1e-3,
                          total_steps=10000,
                          warmup_proportion=0.1,
                          min_lr=1e-5,
                      ),sync_period=6, slow_step_size=0.5),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])
  
model.fit(train_dataset, epochs=10, validation_data=test_dataset,
          callbacks=[tfa.callbacks.TQDMProgressBar()], verbose=0)

Other info / logs


TypeError                                 Traceback (most recent call last)
<ipython-input-1-84d7c4d8b77c> in <module>()
     57                 metrics=['sparse_categorical_accuracy'])
     58 
---> 59 model.fit(train_dataset,epochs=10,validation_data=test_dataset, callbacks=[tfa.callbacks.TQDMProgressBar()],verbose=0)

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    849                 batch_size=batch_size):
    850               callbacks.on_train_batch_begin(step)
--> 851               tmp_logs = train_function(iterator)
    852               # Catch OutOfRangeError for Datasets of unknown size.
    853               # This blocks until the batch has finished executing.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    625       # This is the first call of __call__, so we have to initialize.
    626       initializers = []
--> 627       self._initialize(args, kwds, add_initializers_to=initializers)
    628     finally:
    629       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    504     self._concrete_stateful_fn = (
    505         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 506             *args, **kwds))
    507 
    508     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2444       args, kwargs = None, None
   2445     with self._lock:
-> 2446       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2447     return graph_function
   2448 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args, kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2665             arg_names=arg_names,
   2666             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667             capture_by_value=self._capture_by_value),
   2668         self._function_attributes,
   2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    439         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    440         # the function a weak reference to itself to avoid a reference cycle.
--> 441         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    442     weak_wrapped_fn = weakref.ref(wrapped_fn)
    443 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

TypeError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:170 run  **
        return self.extended.tpu_run(fn, args, kwargs, options)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:863 tpu_run
        return func(args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:930 tpu_function
        padding_spec=padding_spec)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py:893 replicate
        padding_spec=padding_spec)[1]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py:1280 split_compile_and_replicate
        outputs = computation(*computation_inputs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:892 replicated_fn
        result[0] = fn(*replica_args, **replica_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:531 train_step  **
        y_pred = self(x, training=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/sequential.py:291 call
        outputs = layer(inputs, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:897 __call__
        self._maybe_build(inputs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:2416 _maybe_build
        self.build(input_shapes)  # pylint:disable=not-callable
    /usr/local/lib/python3.6/dist-packages/tensorflow_addons/layers/wrappers.py:119 build
        self._naked_clone_layer.set_weights(self.layer.get_weights())
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:1588 get_weights
        return backend.batch_get_value(output_weights)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:3327 batch_get_value
        return [x.numpy() for x in tensors]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:3327 <listcomp>
        return [x.numpy() for x in tensors]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:102 numpy
        return self.read_value().numpy()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:135 read_value
        return self._read_variable_op()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:129 _read_variable_op
        return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py:475 read_variable_op
        resource, dtype=dtype, name=name, ctx=_ctx)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py:502 read_variable_op_eager_fallback
        attrs=_attrs, ctx=ctx, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
        raise e
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
        inputs, attrs, num_outputs)

    TypeError: An op outside of the function building code is being passed
    a "Graph" tensor. It is possible to have Graph tensors
    leak out of the function building context by including a
    tf.init_scope in your function building code.
    For example, the following function will fail:
      @tf.function
      def has_init_scope():
        my_constant = tf.constant(1.)
        with tf.init_scope():
          added = my_constant * 2
    The graph tensor has name: sequential/weight_normalization/sequential/weight_normalization/kernel_140486508184800/handle:0
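
Reading the traceback, the failure originates in the wrapper's build(): self._naked_clone_layer.set_weights(self.layer.get_weights()) ends up calling .numpy() on a TPU-replicated variable while the train_function graph is still being traced. A minimal sketch of that general failure mode, independent of the layer (assumption: the exact exception class differs for TPU-replicated variables, but the mechanism is the same):

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def read_eagerly_inside_graph():
    # There is no eager value available during tracing, so this raises
    # (NotImplementedError for a plain variable; it surfaces as the
    # TypeError above for TPU-replicated variables).
    return v.numpy()

# read_eagerly_inside_graph()  # raises when traced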

Activity

sourcecode369 (Author) commented on Apr 18, 2020

It throws the same error under the GPU distribution strategy as well; see #740.

Thanks, and I hope to see a fix soon.

evanatyourservice commented on Aug 29, 2021

For what it's worth, I was able to get it to work with data_init=False: the failing call in your traceback, self._naked_clone_layer.set_weights(self.layer.get_weights()), is only reached when data_init=True.
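
A minimal sketch of that workaround applied to the original repro (layer sizes copied from above; note that disabling data_init also gives up the paper's data-dependent initialization):

import tensorflow as tf
import tensorflow_addons as tfa

def create_model():
    # data_init=False skips the set_weights()/get_weights() call in build()
    # that fails under TPUStrategy.
    return tf.keras.Sequential([
        tfa.layers.WeightNormalization(
            tf.keras.layers.Conv2D(32, 3, activation="elu",
                                   input_shape=(28, 28, 1)),
            data_init=False),
        tf.keras.layers.Flatten(),
        tfa.layers.WeightNormalization(tf.keras.layers.Dense(128, "elu"),
                                       data_init=False),
        tfa.layers.WeightNormalization(tf.keras.layers.Dense(10),
                                       data_init=False),
    ])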

evanatyourservice commented on Aug 29, 2021

I suspected this might be a problem before I even tried tfa's weight normalization on TPU. The real question behind this issue is: how do you perform the weight-norm paper's data-dependent weight initialization under a TensorFlow distribution strategy?
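
One possible approach, sketched here as an untested assumption rather than a verified fix: run the data-dependent initialization eagerly on CPU, then copy the resulting scale g and wrapped-layer weights into an identical model built under the strategy with data_init=False. The names build_model and the reshape are illustrative; train_dataset and tpu_strategy are from the repro above.

import tensorflow as tf
import tensorflow_addons as tfa

def build_model(data_init):
    return tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(784,)),
        tfa.layers.WeightNormalization(
            tf.keras.layers.Dense(128, "elu"), data_init=data_init),
        tfa.layers.WeightNormalization(
            tf.keras.layers.Dense(10), data_init=data_init),
    ])

# 1) Run the data-dependent init eagerly on CPU with one real batch.
cpu_model = build_model(data_init=True)
sample_batch = next(iter(train_dataset))[0]          # a representative batch
sample_batch = tf.reshape(sample_batch, (-1, 784))   # flatten to fit the toy model
cpu_model(sample_batch)                              # triggers the data-dependent init

# 2) Rebuild under the strategy without the failing init path, then copy weights.
with tpu_strategy.scope():
    tpu_model = build_model(data_init=False)
    for cpu_l, tpu_l in zip(cpu_model.layers, tpu_model.layers):
        if isinstance(cpu_l, tfa.layers.WeightNormalization):
            tpu_l.g.assign(cpu_l.g)                             # learned scale
            tpu_l.layer.set_weights(cpu_l.layer.get_weights())  # direction + bias
    # Caveat: some tfa versions also track an internal "initialized" flag that
    # re-runs initialization on the first call; verify against your tfa version.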

dharmanibc commented on Mar 20, 2023

The problem still persists on both GPU and TPU; even a minimal snippet fails:

import numpy as np
import tensorflow as tf
import tensorflow_addons.layers as addon_layers

x = np.random.rand(1, 10, 10, 1)
dense = addon_layers.WeightNormalization(tf.keras.layers.Dense(10), data_init=False)
y = dense(x)
y.shape


TypeError                                 Traceback (most recent call last)
<ipython-input> in <module>()
      7
      8 x = np.random.rand(1, 10, 10, 1)
----> 9 dense = addon_layers.WeightNormalization(tf.keras.layers.Dense(10), data_init=False)
     10 y = dense(x)
     11 y.shape

2 frames
/usr/local/lib/python3.9/dist-packages/tensorflow_addons/layers/wrappers.py in __init__(self, layer, data_init, **kwargs)
     57
     58     @typechecked
---> 59     def __init__(self, layer: tf.keras.layers, data_init: bool = True, **kwargs):
     60         super().__init__(layer, **kwargs)
     61         self.data_init = data_init

/usr/local/lib/python3.9/dist-packages/typeguard/_functions.py in check_argument_types(memo)
    111         value = memo.arguments[argname]
    112         try:
--> 113             check_type_internal(value, expected_type, memo=memo)
    114         except TypeCheckError as exc:
    115             qualname = qualified_name(value, add_class_prefix=True)

/usr/local/lib/python3.9/dist-packages/typeguard/_checkers.py in check_type_internal(value, annotation, memo)
    668         return
    669
--> 670     if not isinstance(value, origin_type):
    671         raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
    672

TypeError: isinstance() arg 2 must be a type or tuple of types
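
Note that this later failure is a different bug from the TPU error above: typeguard 3.x isinstance()-checks the argument against the annotation tf.keras.layers, which is a module rather than a type, so construction fails before any TPU code runs. A minimal illustration follows; pinning typeguard below 3.0 as a remedy is an assumption based on this traceback, not a confirmed fix.

import tensorflow as tf

# tf.keras.layers is a module object, not a class, so isinstance() rejects it,
# reproducing the error above:
isinstance(tf.keras.layers.Dense(10), tf.keras.layers)
# TypeError: isinstance() arg 2 must be a type or tuple of types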
