From 13f54e8bf2c91b75f374afee39997644aa9673cd Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 5 Mar 2026 16:31:11 +0530 Subject: [PATCH 1/2] Fix nested list/dict layers not being saved/loaded correctly When a layer stores sub-layers in nested containers (e.g. self.blocks = [[Dense, Dense], [Dense, Dense]]), the save/load functions silently skipped inner containers because _save_container_state and _load_container_state only handled KerasSaveable items, not nested list/dict/tuple/set. Add recursive handling of nested containers with index-based namespacing (_container, _container_1, etc.) to avoid HDF5 path collisions. Fixes https://github.com/keras-team/keras/issues/20598 --- keras/src/saving/saving_lib.py | 31 ++++++++++++++++++++++++++ keras/src/saving/saving_lib_test.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py index e73390c28ea3..a0f1b4c5e2b1 100644 --- a/keras/src/saving/saving_lib.py +++ b/keras/src/saving/saving_lib.py @@ -893,6 +893,20 @@ def _save_container_state( inner_path=file_utils.join(inner_path, name).replace("\\", "/"), visited_saveables=visited_saveables, ) + elif isinstance(saveable, (list, dict, tuple, set)): + name = "_container" + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + _save_container_state( + saveable, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + visited_saveables=visited_saveables, + ) def _load_container_state( @@ -929,6 +943,23 @@ def _load_container_state( failed_saveables=failed_saveables, error_msgs=error_msgs, ) + elif isinstance(saveable, (list, dict, tuple, set)): + name = "_container" + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + _load_container_state( + saveable, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) class DiskIOStore: diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py index 59f7c3473aed..67e7fb5c123a 100644 --- a/keras/src/saving/saving_lib_test.py +++ b/keras/src/saving/saving_lib_test.py @@ -1132,6 +1132,40 @@ def test_bidirectional_lstm_saving(self): out = new_model(x) self.assertAllClose(ref_out, out) + def test_nested_list_layer_saving(self): + """Test that layers stored in nested lists are saved/loaded.""" + + @keras.saving.register_keras_serializable(package="test") + class NestedBlockModel(keras.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.blocks = [ + [keras.layers.Dense(8), keras.layers.Dense(8)], + [keras.layers.Dense(8), keras.layers.Dense(8)], + ] + self.out_layer = keras.layers.Dense(2) + + def call(self, x): + for block in self.blocks: + for layer in block: + x = layer(x) + return self.out_layer(x) + + def get_config(self): + return super().get_config() + + model = NestedBlockModel() + x = np.random.random((2, 4)) + ref_out = model(x) + + temp_filepath = os.path.join( + self.get_temp_dir(), "nested_list.keras" + ) + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + out = new_model(x) + self.assertAllClose(ref_out, out) + def test_remove_weights_only_saving_and_loading(self): def is_remote_path(path): return True From 92c8aa5fe83e22c3ba51c52dd72c7e4a58c8e5b4 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 5 Mar 2026 17:25:39 +0530 Subject: [PATCH 2/2] style: run pre-commit hooks (ruff formatting) --- keras/src/saving/saving_lib_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py index 67e7fb5c123a..1e8d097886b5 100644 --- a/keras/src/saving/saving_lib_test.py +++ b/keras/src/saving/saving_lib_test.py @@ -1158,9 +1158,7 @@ def get_config(self): x = np.random.random((2, 4)) ref_out = model(x) - temp_filepath = os.path.join( - self.get_temp_dir(), "nested_list.keras" - ) + temp_filepath = os.path.join(self.get_temp_dir(), "nested_list.keras") model.save(temp_filepath) new_model = keras.saving.load_model(temp_filepath) out = new_model(x)