Description
Environment: Google Colab
Hardware Accelerator: TPU
TensorFlow version: 2.2.0-rc3
TensorFlow Addons: 0.9.1
Describe the bug
The Weight Normalization layer (tfa.layers.WeightNormalization) throws an error when running on TPU.
Code to reproduce the issue
import tensorflow as tf
print(f"tf.__version__: {tf.__version__}")
tf.config.optimizer.set_jit(True)
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
import tensorflow_datasets as tfds
import os

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    strategy = tf.distribute.get_strategy()
    raise BaseException('ERROR: Not connected to a TPU runtime.')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)
print("REPLICAS: ", tpu_strategy.num_replicas_in_sync)


def get_dataset(batch_size=200):
    datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                               try_gcs=True)
    mnist_train, mnist_test = datasets['train'], datasets['test']

    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255.0
        return image, label

    train_dataset = mnist_train.map(scale).cache().shuffle(10000).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    test_dataset = mnist_test.map(scale).cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    return train_dataset, test_dataset


def create_model():
    return tf.keras.Sequential(
        [tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(32, 3, activation="elu", input_shape=(28, 28, 1))),
         tf.keras.layers.Flatten(),
         tfa.layers.WeightNormalization(tf.keras.layers.Dense(128, "elu")),
         tfa.layers.WeightNormalization(tf.keras.layers.Dense(10))])


train_dataset, test_dataset = get_dataset()

with tpu_strategy.scope():
    model = create_model()
    model.compile(optimizer=tfa.optimizers.Lookahead(tfa.optimizers.RectifiedAdam(
                      lr=1e-3,
                      total_steps=10000,
                      warmup_proportion=0.1,
                      min_lr=1e-5,
                  ), sync_period=6, slow_step_size=0.5),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['sparse_categorical_accuracy'])

model.fit(train_dataset, epochs=10, validation_data=test_dataset, callbacks=[tfa.callbacks.TQDMProgressBar()], verbose=0)
Other info / logs
TypeError Traceback (most recent call last)
<ipython-input-1-84d7c4d8b77c> in <module>()
57 metrics=['sparse_categorical_accuracy'])
58
---> 59 model.fit(train_dataset,epochs=10,validation_data=test_dataset, callbacks=[tfa.callbacks.TQDMProgressBar()],verbose=0)
10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
849 batch_size=batch_size):
850 callbacks.on_train_batch_begin(step)
--> 851 tmp_logs = train_function(iterator)
852 # Catch OutOfRangeError for Datasets of unknown size.
853 # This blocks until the batch has finished executing.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
578 xla_context.Exit()
579 else:
--> 580 result = self._call(*args, **kwds)
581
582 if tracing_count == self._get_tracing_count():
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
625 # This is the first call of __call__, so we have to initialize.
626 initializers = []
--> 627 self._initialize(args, kwds, add_initializers_to=initializers)
628 finally:
629 # At this point we know that the initialization is complete (or less
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
504 self._concrete_stateful_fn = (
505 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 506 *args, **kwds))
507
508 def invalid_creator_scope(*unused_args, **unused_kwds):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2444 args, kwargs = None, None
2445 with self._lock:
-> 2446 graph_function, _, _ = self._maybe_define_function(args, kwargs)
2447 return graph_function
2448
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2775
2776 self._function_cache.missed.add(call_context_key)
-> 2777 graph_function = self._create_graph_function(args, kwargs)
2778 self._function_cache.primary[cache_key] = graph_function
2779 return graph_function, args, kwargs
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2665 arg_names=arg_names,
2666 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667 capture_by_value=self._capture_by_value),
2668 self._function_attributes,
2669 # Tell the ConcreteFunction to clean up its graph once it goes out of
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
979 _, original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args, **func_kwargs)
982
983 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
439 # __wrapped__ allows AutoGraph to swap in a converted function. We give
440 # the function a weak reference to itself to avoid a reference cycle.
--> 441 return weak_wrapped_fn().__wrapped__(*args, **kwds)
442 weak_wrapped_fn = weakref.ref(wrapped_fn)
443
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
TypeError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function *
outputs = self.distribute_strategy.run(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:170 run **
return self.extended.tpu_run(fn, args, kwargs, options)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:863 tpu_run
return func(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:930 tpu_function
padding_spec=padding_spec)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py:893 replicate
padding_spec=padding_spec)[1]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py:1280 split_compile_and_replicate
outputs = computation(*computation_inputs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:892 replicated_fn
result[0] = fn(*replica_args, **replica_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:531 train_step **
y_pred = self(x, training=True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/sequential.py:291 call
outputs = layer(inputs, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:897 __call__
self._maybe_build(inputs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:2416 _maybe_build
self.build(input_shapes) # pylint:disable=not-callable
/usr/local/lib/python3.6/dist-packages/tensorflow_addons/layers/wrappers.py:119 build
self._naked_clone_layer.set_weights(self.layer.get_weights())
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:1588 get_weights
return backend.batch_get_value(output_weights)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:3327 batch_get_value
return [x.numpy() for x in tensors]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:3327 <listcomp>
return [x.numpy() for x in tensors]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:102 numpy
return self.read_value().numpy()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:135 read_value
return self._read_variable_op()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:129 _read_variable_op
return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py:475 read_variable_op
resource, dtype=dtype, name=name, ctx=_ctx)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py:502 read_variable_op_eager_fallback
attrs=_attrs, ctx=ctx, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
raise e
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
inputs, attrs, num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: sequential/weight_normalization/sequential/weight_normalization/kernel_140486508184800/handle:0
Activity
sourcecode369 commented on Apr 18, 2020
It throws an error with the GPU distribution strategy as well; see #740.
Thanks, and I hope to see a fix soon.
evanatyourservice commented on Aug 29, 2021
For what it's worth, I was able to get it to work with data_init=False, because the line that fails in your traceback,
self._naked_clone_layer.set_weights(self.layer.get_weights())
is only called when data_init=True.
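For reference, here is a minimal sketch of that workaround applied to the repro above (assuming the same Colab TPU setup and the tpu_strategy object defined there; the optimizer is simplified to plain Adam for brevity):

import tensorflow as tf
import tensorflow_addons as tfa

# Sketch only: with data_init=False, WeightNormalization.build() skips the
# get_weights()/set_weights() call that fails on TPU, so the model can be
# built under the strategy scope.
def create_model_no_data_init():
    return tf.keras.Sequential([
        tfa.layers.WeightNormalization(
            tf.keras.layers.Conv2D(32, 3, activation="elu", input_shape=(28, 28, 1)),
            data_init=False),
        tf.keras.layers.Flatten(),
        tfa.layers.WeightNormalization(tf.keras.layers.Dense(128, "elu"), data_init=False),
        tfa.layers.WeightNormalization(tf.keras.layers.Dense(10), data_init=False),
    ])

with tpu_strategy.scope():  # tpu_strategy as defined in the original repro
    model = create_model_no_data_init()
    model.compile(optimizer="adam",
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=["sparse_categorical_accuracy"])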
evanatyourservice commented on Aug 29, 2021
I actually suspected this might be a problem before I even tried tfa weight normalization on TPU. The real question for this issue is: how do you do the weight-norm paper's data-dependent weight init under a tf distribution strategy?
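For context, the data-dependent init from the weight-normalization paper (Salimans & Kingma, 2016) sets the initial scale g and bias b from pre-activation statistics of a first minibatch. A rough single-device sketch for a Dense layer follows (illustrative names only; it does not answer how to run this consistently across replicas, which is exactly the open question here):

import tensorflow as tf

def data_dependent_init_dense(v, x_batch, eps=1e-8):
    # v: direction weights of shape (in_dim, out_dim); x_batch: (batch, in_dim).
    # Returns (g, b) so the layer's initial outputs have ~zero mean and unit variance.
    v_norm = tf.norm(v, axis=0)                 # per-unit ||v||
    t = tf.matmul(x_batch, v) / v_norm          # pre-activations with g = 1, b = 0
    mean, var = tf.nn.moments(t, axes=[0])      # statistics over the init batch
    g = 1.0 / tf.sqrt(var + eps)
    b = -mean * g
    return g, b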
dharmanibc commented on Mar 20, 2023
The problem still persists for both GPU and TPU. Even a simple code sample raises an error.
import numpy as np
import tensorflow as tf
from tensorflow_addons import layers as addon_layers

x = np.random.rand(1, 10, 10, 1)
dense = addon_layers.WeightNormalization(tf.keras.layers.Dense(10),
                                         data_init=False)
y = dense(x)
y.shape
TypeError Traceback (most recent call last)
in
7
8 x = np.random.rand(1, 10, 10, 1)
----> 9 dense = addon_layers.WeightNormalization(tf.keras.layers.Dense(10), data_init=False)
10 y = dense(x)
11 y.shape
2 frames
/usr/local/lib/python3.9/dist-packages/tensorflow_addons/layers/wrappers.py in __init__(self, layer, data_init, **kwargs)
57
58 @typechecked
---> 59 def __init__(self, layer: tf.keras.layers, data_init: bool = True, **kwargs):
60 super().__init__(layer, **kwargs)
61 self.data_init = data_init
/usr/local/lib/python3.9/dist-packages/typeguard/_functions.py in check_argument_types(memo)
111 value = memo.arguments[argname]
112 try:
--> 113 check_type_internal(value, expected_type, memo=memo)
114 except TypeCheckError as exc:
115 qualname = qualified_name(value, add_class_prefix=True)
/usr/local/lib/python3.9/dist-packages/typeguard/_checkers.py in check_type_internal(value, annotation, memo)
668 return
669
--> 670 if not isinstance(value, origin_type):
671 raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
672
TypeError: isinstance() arg 2 must be a type or tuple of types
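Note that this 2023 traceback looks like a separate problem from the original TPU one: it fails inside typeguard's argument check because the annotation layer: tf.keras.layers refers to a module rather than a type, which typeguard 3.x rejects with exactly this isinstance() error. A commonly reported workaround (an assumption here, not an official fix) is to pin an older typeguard in the Colab runtime before importing tensorflow_addons:

# Workaround sketch for the typeguard failure (the version bound is an assumption):
!pip install -q "typeguard<3.0.0"
# Then restart the runtime and re-import tensorflow_addons.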