
Fix #20322 — Weight save/load fails when model uses a custom subclass of a built-in layer (e.g. LSTM) #22360

Draft
pctablet505 wants to merge 2 commits into keras-team:master from pctablet505:fix/20322-lstm-weight-loading

Conversation

@pctablet505

Fixes: #20322
This pull request improves the robustness and compatibility of Keras model weight saving and loading, especially when dealing with custom layer subclasses that share names with built-in layers. The changes ensure that weights can be reliably transferred between models with matching layer names, even if the layer classes differ, and add comprehensive tests to cover this scenario.

Core logic improvements:

  • Added the _get_container_item_name helper to consistently use the layer's name for topology-based matching when saving/loading weights, with fallback to the class name for unnamed saveables. This ensures compatibility between custom and base class layers with the same name.
  • Updated _save_container_state and _load_container_state to use the new naming logic, and in _load_container_state, added a fallback mechanism: if the new name-based path isn't found in the weights store, it tries the legacy class-name-based path for backward compatibility.

Testing improvements:

  • Added a regression test (test_custom_subclass_weight_loading) in keras/src/saving/saving_lib_test.py: save weights from a model using a CustomLSTM subclass, load them into a model using the base LSTM with the same layer name, and assert the predictions match.

Problem

When saving weights, Keras walks the layer tree and stores each sublayer keyed by to_snake_case(cls.__name__). If the saved model uses class CustomLSTM(LSTM), the key written is custom_lstm. When loading into a model that uses the base LSTM class (same layer name, same weights), the key expected is lstm — mismatch, and weights are silently skipped.

The same problem occurs in reverse: saving with the base class and loading into a subclass.
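The key mismatch is easy to reproduce with a minimal sketch of the snake-case conversion; the function below is an illustrative reimplementation of the conversion Keras applies to class names, not the actual `keras.src.naming.to_snake_case`:

```python
import re

def to_snake_case(name):
    # Insert "_" at case transitions, then lowercase -- an illustrative
    # stand-in for the conversion Keras applies to class names.
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
    return name.lower()

# The key written at save time and the key expected at load time differ,
# even though the layer name and weight shapes are identical:
print(to_snake_case("CustomLSTM"))  # custom_lstm
print(to_snake_case("LSTM"))        # lstm
```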

Root Cause

The original code comment acknowledged this explicitly but chose class-name-based keys to avoid autogenerated-name drift between instances. However this choice breaks cross-class weight transfer even when the layer names are identical and the weight shapes match exactly.

Fix

Introduce _get_container_item_name(saveable): prefer saveable.name (the user-assigned or Keras-deduped string name) as the key, falling back to the class-name approach only when no string name is available. On load, add backward-compatible fallback: if the name-based path doesn't exist in the store, retry with the old class-name-based path. This preserves compatibility with weights files saved before this fix.
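The lookup order this describes can be sketched with a plain dict standing in for the weights store; the helpers and classes below are illustrative stand-ins, not the exact saving_lib code:

```python
import re

def to_snake_case(name):
    # Illustrative snake-case conversion for class names.
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
    return name.lower()

def get_container_item_name(saveable):
    # Prefer the instance's string name; fall back to the class name.
    name = getattr(saveable, "name", None)
    if isinstance(name, str):
        return name
    return to_snake_case(type(saveable).__name__)

def resolve_path(store, prefix, saveable):
    # Try the name-based path first; retry with the legacy
    # class-name-based path for files saved before the fix.
    path = f"{prefix}/{get_container_item_name(saveable)}"
    if path in store:
        return path
    legacy = f"{prefix}/{to_snake_case(type(saveable).__name__)}"
    return legacy if legacy in store else None

class LSTM:
    def __init__(self, name):
        self.name = name

class CustomLSTM(LSTM):
    pass

# New-format file: keyed by layer name, so cross-class transfer works
# whenever the layer names match.
new_store = {"layers/my_lstm": b"weights"}
assert resolve_path(new_store, "layers", LSTM("my_lstm")) == "layers/my_lstm"
assert resolve_path(new_store, "layers", CustomLSTM("my_lstm")) == "layers/my_lstm"

# Legacy file keyed by class name still loads via the fallback.
legacy_store = {"layers/lstm": b"weights"}
assert resolve_path(legacy_store, "layers", LSTM("my_lstm")) == "layers/lstm"
```

Note the fallback keys on the *loading* model's class name, which is what makes files written by the old format keep loading into the same model classes as before.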

Files Changed

  • keras/src/saving/saving_lib.py — added _get_container_item_name() and _store_has_path(); updated _save_container_state and _load_container_state
  • keras/src/saving/saving_lib_test.py — regression test: save from CustomLSTM, load into vanilla LSTM, assert predictions match

…state serialization

The _save_container_state and _load_container_state functions used
class names to construct HDF5 paths for container layers. This meant
saving weights from a custom LSTM subclass (e.g. MyCustomLSTM) and
loading into a vanilla LSTM model would fail even when layer names
matched, because the HDF5 paths differed (my_custom_lstm vs lstm).

Use saveable.name (the layer name) instead of the class name for
topology-based path matching. Add backward-compatible fallback to
class-name paths so files saved with the old format still load.

Fixes keras-team#20322
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue where Keras model weight saving and loading would fail or silently skip weights when models incorporated custom subclasses of built-in layers, such as LSTM. The changes introduce a more robust naming strategy for saveable items, prioritizing explicit layer names for consistent weight mapping, and include a backward-compatible loading mechanism. This significantly improves the reliability of weight transfer across models with similar architectures but differing layer implementations, resolving issue #20322.

Highlights

  • Improved Weight Saving/Loading: Enhanced the mechanism for saving and loading model weights to ensure compatibility when using custom subclasses of built-in Keras layers (e.g., CustomLSTM vs. LSTM).
  • Consistent Naming Strategy: Introduced a new helper function _get_container_item_name that prioritizes a layer's name for topology-based matching during weight operations, falling back to the class name only when a specific name is unavailable.
  • Backward Compatibility: Implemented a fallback mechanism in _load_container_state to try the legacy class-name-based path if the new name-based path is not found, ensuring compatibility with older weight files.
  • Comprehensive Testing: Added a dedicated regression test (test_custom_subclass_weight_loading) to validate that weights can be correctly transferred between models using custom subclasses and their base classes when layer names match.


Changelog
  • keras/src/saving/saving_lib.py
    • Added _get_container_item_name to determine the appropriate name for saveables, preferring saveable.name over class name.
    • Added _store_has_path to check for the existence of a path within a weights store.
    • Modified _save_container_state to utilize _get_container_item_name for generating saveable paths.
    • Updated _load_container_state to use _get_container_item_name and incorporate a backward-compatible fallback to class-name-based paths if the name-based path is not found.
  • keras/src/saving/saving_lib_test.py
    • Added test_custom_subclass_weight_loading to verify weight loading between models with custom layer subclasses and base classes.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a valuable fix for model weight saving and loading when using custom subclasses of built-in layers, improving compatibility by prioritizing the layer's name over its class for key generation and including a fallback for backward compatibility. However, it introduces a critical path traversal vulnerability by using unsanitized layer names to construct file paths during model saving and loading, which could allow a malicious model to read or write arbitrary files on the host system. Sanitizing layer names before using them as path components is strongly recommended. Additionally, there's a potential bug in the new _store_has_path helper function that could affect sharded weights, and the new test case could be more comprehensive by covering the reverse scenario.

Comment on lines +877 to +879
```python
if hasattr(saveable, "name") and isinstance(saveable.name, str):
    return saveable.name
return naming.to_snake_case(saveable.__class__.__name__)
```


Severity: high (security)

The _get_container_item_name function returns saveable.name without any sanitization. This name is then used to construct file paths for saving and loading model assets via DiskIOStore. Since saveable.name can be controlled by a malicious model configuration file (config.json), an attacker can use path traversal sequences (e.g., ../) or absolute paths to read or write files outside the intended temporary directory. For example, a layer named ../../../../etc/passwd would cause DiskIOStore.get to return /etc/passwd, which is then passed to the layer's load_assets method, potentially leading to an arbitrary file read.
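For illustration only (the directory and layer name below are hypothetical), the traversal risk can be reproduced with plain path arithmetic:

```python
import posixpath

store_root = "/tmp/keras_assets"       # hypothetical asset directory
layer_name = "../../../../etc/passwd"  # attacker-controlled layer name
resolved = posixpath.normpath(posixpath.join(store_root, layer_name))
print(resolved)  # /etc/passwd -- escapes the store root entirely
```

Snake-casing the name (as in the suggestion below) strips the `/` and `.` characters, which is why it doubles as sanitization here.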

Suggested change:

```diff
 if hasattr(saveable, "name") and isinstance(saveable.name, str):
-    return saveable.name
+    return naming.to_snake_case(saveable.name)
 return naming.to_snake_case(saveable.__class__.__name__)
```

Comment on lines +882 to +890
```python
def _store_has_path(weights_store, path):
    """Check if a path exists in the weights store."""
    if weights_store is None or not path:
        return False
    if isinstance(weights_store, H5IOStore):
        return path in weights_store.h5_file
    if isinstance(weights_store, NpzIOStore):
        return path in weights_store.contents
    return True
```


Severity: high

The implementation of _store_has_path has a couple of potential issues that could lead to incorrect behavior, especially with sharded weights.

  1. Incorrect type checking order: ShardedH5IOStore is a subclass of H5IOStore. The current check isinstance(weights_store, H5IOStore) will evaluate to True for a ShardedH5IOStore instance, causing it to execute logic that is incorrect for sharded stores as it only checks the current shard. The check for ShardedH5IOStore should be performed before the check for H5IOStore.

  2. Unsafe fallback: The function returns True as a fallback for any store type that is not H5IOStore or NpzIOStore. This is unsafe because it assumes the path exists in any unknown or unhandled store type. It would be safer to return False.

Here is a suggested implementation that addresses these points:

```diff
 def _store_has_path(weights_store, path):
     """Check if a path exists in the weights store."""
     if weights_store is None or not path:
         return False
+    # `ShardedH5IOStore` must be checked before `H5IOStore` due to inheritance.
+    if isinstance(weights_store, ShardedH5IOStore):
+        weight_map = weights_store.sharding_config["weight_map"]
+        # The path in the weight map is typically `/{path}/vars`.
+        return path in weight_map or f"/{path}/vars" in weight_map
     if isinstance(weights_store, H5IOStore):
         return path in weights_store.h5_file
     if isinstance(weights_store, NpzIOStore):
         return path in weights_store.contents
-    return True
+    return False
```

Comment on lines +1182 to +1184
```python
# Verify predictions match
x = np.random.random((1, 10, 1)).astype("float32")
self.assertAllClose(model_a(x), model_b(x))
```


Severity: medium

The test docstring mentions testing loading 'and vice versa', but the test currently only covers saving from a custom subclass and loading into the base class. To make the test more comprehensive and match its description, consider adding a test for the other direction: saving from the base LSTM and loading into the CustomLSTM subclass.

Suggested change:

```python
        # Verify predictions match
        x = np.random.random((1, 10, 1)).astype("float32")
        self.assertAllClose(model_a(x), model_b(x))

        # Test the other direction: save from vanilla, load into subclass
        temp_filepath_2 = os.path.join(
            self.get_temp_dir(), "vanilla_lstm.weights.h5"
        )
        # Re-build model_a to reset its weights before loading
        inputs_a_2 = keras.Input(shape=(10, 1))
        lstm_a_2 = CustomLSTM(32, name="my_lstm")
        dense_a_2 = keras.layers.Dense(1, name="output")
        model_a_2 = keras.Model(
            inputs_a_2, dense_a_2(lstm_a_2(inputs_a_2)), name="model_a"
        )

        model_b.save_weights(temp_filepath_2)
        model_a_2.load_weights(temp_filepath_2)

        # Verify predictions match
        self.assertAllClose(model_a_2(x), model_b(x))
```
