diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py index e73390c28ea3..db7f702a87ec 100644 --- a/keras/src/saving/saving_lib.py +++ b/keras/src/saving/saving_lib.py @@ -866,6 +866,30 @@ def _load_state( error_msgs.pop(id(saveable)) +def _get_container_item_name(saveable): + """Get the name to use for a saveable in a container. + + Uses `saveable.name` for stable topology-based matching. This ensures + that layers with the same name but different classes (e.g., a custom + LSTM subclass vs vanilla LSTM) map to the same path in the weights + file. Falls back to the class name for saveables without a name. + """ + if hasattr(saveable, "name") and isinstance(saveable.name, str): + return saveable.name + return naming.to_snake_case(saveable.__class__.__name__) + + +def _store_has_path(weights_store, path): + """Check if a path exists in the weights store.""" + if weights_store is None or not path: + return False + if isinstance(weights_store, H5IOStore): + return path in weights_store.h5_file + if isinstance(weights_store, NpzIOStore): + return path in weights_store.contents + return True + + def _save_container_state( container, weights_store, assets_store, inner_path, visited_saveables ): @@ -877,10 +901,7 @@ def _save_container_state( for saveable in container: if isinstance(saveable, KerasSaveable): - # Do NOT address the saveable via `saveable.name`, since - # names are usually autogenerated and thus not reproducible - # (i.e. they may vary across two instances of the same model). - name = naming.to_snake_case(saveable.__class__.__name__) + name = _get_container_item_name(saveable) if name in used_names: used_names[name] += 1 name = f"{name}_{used_names[name]}" @@ -908,22 +929,43 @@ def _load_container_state( from keras.src.saving.keras_saveable import KerasSaveable used_names = {} + used_class_names = {} if isinstance(container, dict): container = list(container.values()) for saveable in container: if isinstance(saveable, KerasSaveable): - name = naming.to_snake_case(saveable.__class__.__name__) + name = _get_container_item_name(saveable) if name in used_names: used_names[name] += 1 name = f"{name}_{used_names[name]}" else: used_names[name] = 0 + candidate_path = file_utils.join(inner_path, name).replace( + "\\", "/" + ) + + # Backward compatibility: if the name-based path doesn't + # exist in the store, fall back to class-name-based path + # (the format used before this fix). + if not _store_has_path(weights_store, candidate_path): + class_name = naming.to_snake_case(saveable.__class__.__name__) + if class_name in used_class_names: + used_class_names[class_name] += 1 + class_name = f"{class_name}_{used_class_names[class_name]}" + else: + used_class_names[class_name] = 0 + fallback_path = file_utils.join(inner_path, class_name).replace( + "\\", "/" + ) + if _store_has_path(weights_store, fallback_path): + candidate_path = fallback_path + _load_state( saveable, weights_store, assets_store, - inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + inner_path=candidate_path, skip_mismatch=skip_mismatch, visited_saveables=visited_saveables, failed_saveables=failed_saveables, diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py index 59f7c3473aed..862a930fb0a1 100644 --- a/keras/src/saving/saving_lib_test.py +++ b/keras/src/saving/saving_lib_test.py @@ -1145,6 +1145,44 @@ def is_remote_path(path): model.save_weights(temp_filepath) model.load_weights(temp_filepath) + def test_custom_subclass_weight_loading(self): + """Test that weights saved from a custom subclass can be loaded + into a model using the base class, and vice versa, when the layer + names match (issue #20322).""" + + @keras.saving.register_keras_serializable(package="custom_lstm_test") + class CustomLSTM(keras.layers.LSTM): + pass + + # Build model with custom LSTM subclass + inputs_a = keras.Input(shape=(10, 1)) + lstm_a = CustomLSTM(32, name="my_lstm") + dense_a = keras.layers.Dense(1, name="output") + model_a = keras.Model( + inputs_a, dense_a(lstm_a(inputs_a)), name="model_a" + ) + + # Build model with vanilla LSTM (same layer name) + inputs_b = keras.Input(shape=(10, 1)) + lstm_b = keras.layers.LSTM(32, name="my_lstm") + dense_b = keras.layers.Dense(1, name="output") + model_b = keras.Model( + inputs_b, dense_b(lstm_b(inputs_b)), name="model_b" + ) + + # Save weights from custom subclass model + temp_filepath = os.path.join( + self.get_temp_dir(), "custom_lstm.weights.h5" + ) + model_a.save_weights(temp_filepath) + + # Load into vanilla LSTM model + model_b.load_weights(temp_filepath) + + # Verify predictions match + x = np.random.random((1, 10, 1)).astype("float32") + self.assertAllClose(model_a(x), model_b(x)) + class SavingH5IOStoreTest(testing.TestCase): def test_h5_io_store_basics(self):