Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions keras/src/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,30 @@ def _load_state(
error_msgs.pop(id(saveable))


def _get_container_item_name(saveable):
"""Get the name to use for a saveable in a container.

Uses `saveable.name` for stable topology-based matching. This ensures
that layers with the same name but different classes (e.g., a custom
LSTM subclass vs vanilla LSTM) map to the same path in the weights
file. Falls back to the class name for saveables without a name.
"""
if hasattr(saveable, "name") and isinstance(saveable.name, str):
return saveable.name
return naming.to_snake_case(saveable.__class__.__name__)
Comment on lines +877 to +879
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The _get_container_item_name function returns saveable.name without any sanitization. This name is then used to construct file paths for saving and loading model assets via DiskIOStore. Since saveable.name can be controlled by a malicious model configuration file (config.json), an attacker can use path traversal sequences (e.g., ../) or absolute paths to read or write files outside the intended temporary directory. For example, a layer named ../../../../etc/passwd would cause DiskIOStore.get to return /etc/passwd, which is then passed to the layer's load_assets method, potentially leading to an arbitrary file read.

Suggested change
if hasattr(saveable, "name") and isinstance(saveable.name, str):
return saveable.name
return naming.to_snake_case(saveable.__class__.__name__)
if hasattr(saveable, "name") and isinstance(saveable.name, str):
return naming.to_snake_case(saveable.name)
return naming.to_snake_case(saveable.__class__.__name__)



def _store_has_path(weights_store, path):
"""Check if a path exists in the weights store."""
if weights_store is None or not path:
return False
if isinstance(weights_store, H5IOStore):
return path in weights_store.h5_file
if isinstance(weights_store, NpzIOStore):
return path in weights_store.contents
return True
Comment on lines +882 to +890
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of _store_has_path has a couple of potential issues that could lead to incorrect behavior, especially with sharded weights.

  1. Incorrect type checking order: ShardedH5IOStore is a subclass of H5IOStore. The current check isinstance(weights_store, H5IOStore) will evaluate to True for a ShardedH5IOStore instance, causing it to execute logic that is incorrect for sharded stores as it only checks the current shard. The check for ShardedH5IOStore should be performed before the check for H5IOStore.

  2. Unsafe fallback: The function returns True as a fallback for any store type that is not H5IOStore or NpzIOStore. This is unsafe because it assumes the path exists in any unknown or unhandled store type. It would be safer to return False.

Here is a suggested implementation that addresses these points:

Suggested change
def _store_has_path(weights_store, path):
"""Check if a path exists in the weights store."""
if weights_store is None or not path:
return False
if isinstance(weights_store, H5IOStore):
return path in weights_store.h5_file
if isinstance(weights_store, NpzIOStore):
return path in weights_store.contents
return True
def _store_has_path(weights_store, path):
"""Check if a path exists in the weights store."""
if weights_store is None or not path:
return False
# `ShardedH5IOStore` must be checked before `H5IOStore` due to inheritance.
if isinstance(weights_store, ShardedH5IOStore):
weight_map = weights_store.sharding_config["weight_map"]
# The path in the weight map is typically `/{path}/vars`.
return path in weight_map or f"/{path}/vars" in weight_map
if isinstance(weights_store, H5IOStore):
return path in weights_store.h5_file
if isinstance(weights_store, NpzIOStore):
return path in weights_store.contents
return False



def _save_container_state(
container, weights_store, assets_store, inner_path, visited_saveables
):
Expand All @@ -877,10 +901,7 @@ def _save_container_state(

for saveable in container:
if isinstance(saveable, KerasSaveable):
# Do NOT address the saveable via `saveable.name`, since
# names are usually autogenerated and thus not reproducible
# (i.e. they may vary across two instances of the same model).
name = naming.to_snake_case(saveable.__class__.__name__)
name = _get_container_item_name(saveable)
if name in used_names:
used_names[name] += 1
name = f"{name}_{used_names[name]}"
Expand Down Expand Up @@ -908,22 +929,43 @@ def _load_container_state(
from keras.src.saving.keras_saveable import KerasSaveable

used_names = {}
used_class_names = {}
if isinstance(container, dict):
container = list(container.values())

for saveable in container:
if isinstance(saveable, KerasSaveable):
name = naming.to_snake_case(saveable.__class__.__name__)
name = _get_container_item_name(saveable)
if name in used_names:
used_names[name] += 1
name = f"{name}_{used_names[name]}"
else:
used_names[name] = 0
candidate_path = file_utils.join(inner_path, name).replace(
"\\", "/"
)

# Backward compatibility: if the name-based path doesn't
# exist in the store, fall back to class-name-based path
# (the format used before this fix).
if not _store_has_path(weights_store, candidate_path):
class_name = naming.to_snake_case(saveable.__class__.__name__)
if class_name in used_class_names:
used_class_names[class_name] += 1
class_name = f"{class_name}_{used_class_names[class_name]}"
else:
used_class_names[class_name] = 0
fallback_path = file_utils.join(inner_path, class_name).replace(
"\\", "/"
)
if _store_has_path(weights_store, fallback_path):
candidate_path = fallback_path

_load_state(
saveable,
weights_store,
assets_store,
inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
inner_path=candidate_path,
skip_mismatch=skip_mismatch,
visited_saveables=visited_saveables,
failed_saveables=failed_saveables,
Expand Down
38 changes: 38 additions & 0 deletions keras/src/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,44 @@ def is_remote_path(path):
model.save_weights(temp_filepath)
model.load_weights(temp_filepath)

def test_custom_subclass_weight_loading(self):
    """Test that weights saved from a custom subclass can be loaded
    into a model using the base class, and vice versa, when the layer
    names match (issue #20322)."""

    @keras.saving.register_keras_serializable(package="custom_lstm_test")
    class CustomLSTM(keras.layers.LSTM):
        pass

    def build_model(lstm_cls, model_name):
        # Same topology and layer names every time; only the LSTM class
        # (and therefore its class name) differs between models.
        inputs = keras.Input(shape=(10, 1))
        lstm = lstm_cls(32, name="my_lstm")
        dense = keras.layers.Dense(1, name="output")
        return keras.Model(inputs, dense(lstm(inputs)), name=model_name)

    x = np.random.random((1, 10, 1)).astype("float32")

    # Direction 1: save from the custom subclass, load into vanilla LSTM.
    model_a = build_model(CustomLSTM, "model_a")
    model_b = build_model(keras.layers.LSTM, "model_b")
    temp_filepath = os.path.join(
        self.get_temp_dir(), "custom_lstm.weights.h5"
    )
    model_a.save_weights(temp_filepath)
    model_b.load_weights(temp_filepath)
    self.assertAllClose(model_a(x), model_b(x))

    # Direction 2: save from vanilla LSTM, load into a freshly built
    # subclass model so its weights start out different from the file.
    model_a_2 = build_model(CustomLSTM, "model_a")
    temp_filepath_2 = os.path.join(
        self.get_temp_dir(), "vanilla_lstm.weights.h5"
    )
    model_b.save_weights(temp_filepath_2)
    model_a_2.load_weights(temp_filepath_2)
    self.assertAllClose(model_a_2(x), model_b(x))
Comment on lines +1182 to +1184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test docstring mentions testing loading 'and vice versa', but the test currently only covers saving from a custom subclass and loading into the base class. To make the test more comprehensive and match its description, consider adding a test for the other direction: saving from the base LSTM and loading into the CustomLSTM subclass.

        # Verify predictions match
        x = np.random.random((1, 10, 1)).astype("float32")
        self.assertAllClose(model_a(x), model_b(x))

        # Test the other direction: save from vanilla, load into subclass
        temp_filepath_2 = os.path.join(
            self.get_temp_dir(), "vanilla_lstm.weights.h5"
        )
        # Re-build model_a to reset its weights before loading
        inputs_a_2 = keras.Input(shape=(10, 1))
        lstm_a_2 = CustomLSTM(32, name="my_lstm")
        dense_a_2 = keras.layers.Dense(1, name="output")
        model_a_2 = keras.Model(
            inputs_a_2, dense_a_2(lstm_a_2(inputs_a_2)), name="model_a"
        )

        model_b.save_weights(temp_filepath_2)
        model_a_2.load_weights(temp_filepath_2)

        # Verify predictions match
        self.assertAllClose(model_a_2(x), model_b(x))



class SavingH5IOStoreTest(testing.TestCase):
def test_h5_io_store_basics(self):
Expand Down
Loading