Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions keras/src/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,20 @@ def _save_container_state(
inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
visited_saveables=visited_saveables,
)
elif isinstance(saveable, (list, dict, tuple, set)):
name = "_container"
if name in used_names:
used_names[name] += 1
name = f"{name}_{used_names[name]}"
else:
used_names[name] = 0
Comment on lines +897 to +902
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for generating a unique name for the container (_container, _container_1, etc.) is duplicated here and in _load_container_state. Consider extracting this logic into a small helper function to improve maintainability and reduce redundancy. This aligns with the principle of keeping APIs modular and automating repetitive tasks (Repository Style Guide, lines 51, 124).

            name = "_container"
            name = _get_unique_container_name(name, used_names)

_save_container_state(
saveable,
weights_store,
assets_store,
inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
visited_saveables=visited_saveables,
)


def _load_container_state(
Expand Down Expand Up @@ -929,6 +943,23 @@ def _load_container_state(
failed_saveables=failed_saveables,
error_msgs=error_msgs,
)
elif isinstance(saveable, (list, dict, tuple, set)):
name = "_container"
if name in used_names:
used_names[name] += 1
name = f"{name}_{used_names[name]}"
else:
used_names[name] = 0
Comment on lines +947 to +952
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block duplicates the name generation logic found in _save_container_state. Extracting this into a helper function would make the code cleaner and easier to maintain. This aligns with the principle of keeping APIs modular and automating repetitive tasks (Repository Style Guide, lines 51, 124).

            name = "_container"
            name = _get_unique_container_name(name, used_names)

_load_container_state(
saveable,
weights_store,
assets_store,
inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
skip_mismatch=skip_mismatch,
visited_saveables=visited_saveables,
failed_saveables=failed_saveables,
error_msgs=error_msgs,
)


class DiskIOStore:
Expand Down
32 changes: 32 additions & 0 deletions keras/src/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,38 @@ def test_bidirectional_lstm_saving(self):
out = new_model(x)
self.assertAllClose(ref_out, out)

def test_nested_list_layer_saving(self):
"""Test that layers stored in nested lists are saved/loaded."""

@keras.saving.register_keras_serializable(package="test")
class NestedBlockModel(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.blocks = [
[keras.layers.Dense(8), keras.layers.Dense(8)],
[keras.layers.Dense(8), keras.layers.Dense(8)],
]
self.out_layer = keras.layers.Dense(2)

def call(self, x):
for block in self.blocks:
for layer in block:
x = layer(x)
return self.out_layer(x)

def get_config(self):
return super().get_config()

model = NestedBlockModel()
x = np.random.random((2, 4))
ref_out = model(x)

temp_filepath = os.path.join(self.get_temp_dir(), "nested_list.keras")
model.save(temp_filepath)
new_model = keras.saving.load_model(temp_filepath)
out = new_model(x)
self.assertAllClose(ref_out, out)

def test_remove_weights_only_saving_and_loading(self):
def is_remote_path(path):
return True
Expand Down