Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions keras/src/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,26 @@ def output_shape(self):
return output_shapes[0]
return output_shapes

def _assert_input_compatibility(self, *args):
return super(Model, self)._assert_input_compatibility(*args)
def _assert_input_compatibility(self, inputs):
if isinstance(self._inputs_struct, dict) and not isinstance(
inputs, dict
):
# Allow list/tuple with matching length (positional matching)
if isinstance(inputs, (list, tuple)):
if len(inputs) == len(self._inputs_struct):
return super(Model, self)._assert_input_compatibility(
inputs
)
keys = list(self._inputs_struct.keys())
raise ValueError(
f'Model "{self.name}" expects inputs as a `dict` with '
f"the following keys: {keys}. Instead received "
f"{type(inputs).__name__}. Pass your data as "
"`model.fit({"
+ ", ".join(f"'{k}': ..." for k in keys)
+ "}, ...)`."
)
Comment on lines +226 to +234
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability, you can construct the example string for the error message in a separate variable. This avoids using + for string concatenation within the f-string, making the code cleaner.

            keys = list(self._inputs_struct.keys())
            example_fit_kwargs = ", ".join(f"'{k}': ..." for k in keys)
            raise ValueError(
                f'Model "{self.name}" expects inputs as a `dict` with '
                f"the following keys: {keys}. Instead received "
                f"{type(inputs).__name__}. Pass your data as "
                f"`model.fit({{{example_fit_kwargs}}}, ...)`."
            )

return super(Model, self)._assert_input_compatibility(inputs)

def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False):
try:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def test_bad_input_spec(self):

# Dict input
model = Functional({"a": input_a, "b": input_b}, outputs)
with self.assertRaisesRegex(ValueError, "expects 2 input"):
with self.assertRaisesRegex(ValueError, r"expects inputs as a `dict`"):
model(np.zeros((2, 3)))
with self.assertRaisesRegex(
ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)"
Expand Down