Open
Description
On iNNvestigate v2.0.1, analyzers inheriting from AnalyzerNetworkBase fail when the model contains a BatchNormalization layer. The error is raised when the analyzer model is built on the first call to analyze, e.g.:
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_input' with dtype float and shape [?,50]
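This might be due to batch normalisation layers keeping moving averages of the mean and variance of the training data, which could break the Keras history that iNNvestigate's create_analyzer_model relies on when reversing the computational graph. The traceback is consistent with this: the error is raised inside Keras's create_keras_history, where backend.function([], op_input)([]) tries to evaluate an op input without feeding the model's input placeholder dense_2_input.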
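For reference, here is a minimal NumPy sketch of what those moving averages are: an exponential moving average of per-feature batch statistics that is updated during fit and then used as a fixed affine map at inference. It assumes Keras's default momentum=0.99 and epsilon=1e-3 and the default zero/one initialisers; the helper names update_moving_stats and batchnorm_inference are hypothetical. This only illustrates the layer's state and does not confirm the hypothesis above.

```python
import numpy as np

# Assumed Keras BatchNormalization defaults
MOMENTUM = 0.99
EPSILON = 1e-3

def update_moving_stats(moving_mean, moving_var, batch):
    """Hypothetical helper: EMA update applied after each training batch."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)  # population variance (ddof=0)
    new_mean = MOMENTUM * moving_mean + (1 - MOMENTUM) * batch_mean
    new_var = MOMENTUM * moving_var + (1 - MOMENTUM) * batch_var
    return new_mean, new_var

def batchnorm_inference(x, gamma, beta, moving_mean, moving_var):
    """Hypothetical helper: inference-mode BN as a fixed per-feature affine map."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + EPSILON) + beta

rng = np.random.default_rng(0)
units = 10  # matches Dense(10) in the reproduction script
moving_mean = np.zeros(units)  # initial moving_mean
moving_var = np.ones(units)    # initial moving_variance
gamma, beta = np.ones(units), np.zeros(units)

# Simulate 50 training batches whose features have mean 2 and std 3
for _ in range(50):
    batch = rng.normal(loc=2.0, scale=3.0, size=(32, units))
    moving_mean, moving_var = update_moving_stats(moving_mean, moving_var, batch)

x = rng.normal(size=(5, units))
y = batchnorm_inference(x, gamma, beta, moving_mean, moving_var)
print(y.shape)  # (5, 10)
```

With momentum 0.99, the moving mean after 50 batches has only moved part of the way from 0 towards 2, which is why these statistics are stored as separate non-trainable state rather than recomputed at inference.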
Minimal example reproducing the issue
import numpy as np
import tensorflow as tf
from keras.layers import BatchNormalization, Dense
from keras.models import Sequential
import innvestigate
# iNNvestigate 2.x requires eager execution to be disabled
tf.compat.v1.disable_eager_execution()
input_shape = (50,)
x = np.random.rand(100, *input_shape)
y = np.random.rand(100, 2)
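# Reference model without BatchNormalization
model1 = Sequential()
model1.add(Dense(10, input_shape=input_shape))
model1.add(Dense(2))
# Same architecture with a BatchNormalization layer in between
model2 = Sequential()
model2.add(Dense(10, input_shape=input_shape))
model2.add(BatchNormalization())
model2.add(Dense(2))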
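def run_analysis(model):
    # Compile and train briefly; for model2 this also updates the BatchNormalization moving statistics
    model.compile(optimizer="adam", loss="mse")
    model.fit(x, y, epochs=10, verbose=0)
    # analyze() builds the analyzer model on first use; the error is raised there
    analyzer = innvestigate.create_analyzer("gradient", model)
    analyzer.analyze(x)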
print("Model without BatchNormalization:") # passes
run_analysis(model1)
print("Model with BatchNormalization:") # errors
run_analysis(model2)
Full stacktrace
Model with BatchNormalization:
/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/client/session.py:1480: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
ret = tf_session.TF_SessionRunCallable(self._session._session,
Traceback (most recent call last):
File "/Users/funks/Developer/innvestigate-issues/open/issue_238_v3", line 35, in <module>
run_analysis(model2)
File "/Users/funks/Developer/innvestigate-issues/open/issue_238_v3", line 29, in run_analysis
analyzer.analyze(x)
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/innvestigate/analyzer/network_base.py", line 250, in analyze
self.create_analyzer_model()
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/innvestigate/analyzer/network_base.py", line 196, in create_analyzer_model
self._analyzer_model = kmodels.Model(
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/training/tracking/base.py", line 629, in _method_wrapper
result = method(self, *args, **kwargs)
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 146, in __init__
self._init_graph_network(inputs, outputs)
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/training/tracking/base.py", line 629, in _method_wrapper
result = method(self, *args, **kwargs)
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 181, in _init_graph_network
base_layer_utils.create_keras_history(self._nested_outputs)
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 175, in create_keras_history
_, created_layers = _create_keras_history_helper(tensors, set(), [])
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
processed_ops, created_layers = _create_keras_history_helper(
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
processed_ops, created_layers = _create_keras_history_helper(
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
processed_ops, created_layers = _create_keras_history_helper(
[Previous line repeated 3 more times]
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 251, in _create_keras_history_helper
constants[i] = backend.function([], op_input)([])
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/backend.py", line 4275, in __call__
fetched = self._callable_fn(*array_vals,
File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/client/session.py", line 1480, in __call__
ret = tf_session.TF_SessionRunCallable(self._session._session,
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_input' with dtype float and shape [?,50]
[[{{node dense_2_input}}]]
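The last frames show Keras's _create_keras_history_helper calling backend.function([], op_input)([]), i.e. evaluating an op input as a constant with no feeds at all. The toy pure-Python graph below sketches why that fails as soon as the evaluated tensor depends on the model input. The classes Placeholder and Op and the function evaluate are hypothetical illustrations, not TensorFlow's or Keras's implementation.

```python
class Placeholder:
    """Hypothetical graph input: has no value unless one is fed."""
    def __init__(self, name):
        self.name = name

class Op:
    """Hypothetical graph node computing fn(*inputs)."""
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

def evaluate(node, feeds):
    """Recursively evaluate a node, looking up placeholders in feeds."""
    if isinstance(node, Placeholder):
        if node.name not in feeds:
            raise ValueError(f"You must feed a value for placeholder '{node.name}'")
        return feeds[node.name]
    if isinstance(node, Op):
        return node.fn(*(evaluate(i, feeds) for i in node.inputs))
    return node  # plain Python values act as constants

x = Placeholder("dense_2_input")
const_branch = Op(lambda a, b: a * b, 2.0, 3.0)         # depends only on constants
input_branch = Op(lambda a, b: a + b, x, const_branch)  # depends on the model input

print(evaluate(const_branch, feeds={}))  # 6.0
try:
    # Analogous to backend.function([], op_input)([]): no feeds are supplied
    evaluate(input_branch, feeds={})
except ValueError as err:
    print(err)  # You must feed a value for placeholder 'dense_2_input'
```

Evaluating with no feeds is only safe for subgraphs built purely from constants; any path back to the input placeholder produces the same kind of error as the one reported above.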