Skip to content

Commit 877fbc6

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Fix a bug impacting serialization order for call methods with mutiple inputs.
This bug made `get_config/from_config` and `to_json/from_json` not idempotent for layers that take multiple inputs in their `call` method. When the functional model is constructed and the multiple inputs are passed as positional arguments, the `Node` object has multiple `call_args` in a list. However, by design, serialization only treats the first argument and positional and serializes all the other arguments as keyword arguments. Upon deserialization, the extra arguments are created as keyword arguments. Their order was modified by `tf.nest.flatten`, which sorts dicts by key. This change preserves the order of keyword arguments, regardless of keys. Fixes #795 PiperOrigin-RevId: 686197321
1 parent 2aa84ae commit 877fbc6

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

tf_keras/engine/functional_test.py

+29
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,35 @@ def test_multi_input_multi_output_recursion(self):
691691
json_str = model.to_json()
692692
models.model_from_json(json_str)
693693

694+
@test_combinations.generate(
695+
test_combinations.combine(mode=["graph", "eager"])
696+
)
697+
def test_multi_input_layer_call(self):
698+
@object_registration.register_keras_serializable()
699+
class MyLayer(layers.Layer):
700+
def call(self, embedding, query_indices, slot_id, position):
701+
return [embedding, query_indices, slot_id, position]
702+
703+
with self.cached_session():
704+
a = layers.Input(shape=(32,), name="input_a")
705+
b = layers.Input(shape=(32,), name="input_b")
706+
c = layers.Input(shape=(32,), name="input_c")
707+
d = layers.Input(shape=(32,), name="input_d")
708+
709+
output = MyLayer()(a, b, c, d)
710+
model = training_lib.Model(
711+
inputs=[a, b, c, d], outputs=output, name="model"
712+
)
713+
714+
config = model.get_config()
715+
model2 = models.Model.from_config(config)
716+
self.assertEqual(model2.get_config(), config)
717+
718+
model.summary()
719+
json_str = model.to_json()
720+
model2 = models.model_from_json(json_str)
721+
self.assertEqual(model2.to_json(), json_str)
722+
694723
@test_combinations.generate(
695724
test_combinations.combine(mode=["graph", "eager"])
696725
)

tf_keras/engine/node.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
8484
self.call_args = call_args
8585
self.call_kwargs = call_kwargs
8686

87-
# Cached for performance.
87+
# Cached for performance. Put kwargs in order of the call method instead
88+
# of using the sorted key order from `tf.nest.flatten`.
8889
self._flat_arguments = tf.nest.flatten(
89-
(self.call_args, self.call_kwargs)
90+
(self.call_args, self.call_kwargs.values())
9091
)
9192
# Used to avoid expensive `nest` operations in the most common case.
9293
self._single_positional_tensor_passed = (
@@ -176,9 +177,13 @@ def map_arguments(self, tensor_dict):
176177
for kt_id, kt_index in self._keras_inputs_ids_and_indices:
177178
flat_arguments[kt_index] = tensor_dict[kt_id].pop()
178179

180+
# Pack the same way as `self._flat_arguments`, i.e. `kwargs` as a
181+
# list in the original order.
179182
args, kwargs = tf.nest.pack_sequence_as(
180-
(self.call_args, self.call_kwargs), flat_arguments
183+
(self.call_args, self.call_kwargs.values()), flat_arguments
181184
)
185+
# Add the keys to `kwargs` to go from a list to a dict.
186+
kwargs = {k: v for k, v in zip(self.call_kwargs.keys(), kwargs)}
182187
return args, kwargs
183188

184189
def serialize(self, make_node_key, node_conversion_map):

0 commit comments

Comments
 (0)