-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Fix #20322 — Weight save/load fails when model uses a custom subclass of a built-in layer (e.g. LSTM) #22360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -866,6 +866,30 @@ def _load_state( | |||||||||||||||||||||||||||||||||||||||||||||||
| error_msgs.pop(id(saveable)) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_container_item_name(saveable): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Get the name to use for a saveable in a container. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Uses `saveable.name` for stable topology-based matching. This ensures | ||||||||||||||||||||||||||||||||||||||||||||||||
| that layers with the same name but different classes (e.g., a custom | ||||||||||||||||||||||||||||||||||||||||||||||||
| LSTM subclass vs vanilla LSTM) map to the same path in the weights | ||||||||||||||||||||||||||||||||||||||||||||||||
| file. Falls back to the class name for saveables without a name. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(saveable, "name") and isinstance(saveable.name, str): | ||||||||||||||||||||||||||||||||||||||||||||||||
| return saveable.name | ||||||||||||||||||||||||||||||||||||||||||||||||
| return naming.to_snake_case(saveable.__class__.__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def _store_has_path(weights_store, path): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Check if a path exists in the weights store.""" | ||||||||||||||||||||||||||||||||||||||||||||||||
| if weights_store is None or not path: | ||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(weights_store, H5IOStore): | ||||||||||||||||||||||||||||||||||||||||||||||||
| return path in weights_store.h5_file | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(weights_store, NpzIOStore): | ||||||||||||||||||||||||||||||||||||||||||||||||
| return path in weights_store.contents | ||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+882
to
+890
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementation of
Here is a suggested implementation that addresses these points:
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def _save_container_state( | ||||||||||||||||||||||||||||||||||||||||||||||||
| container, weights_store, assets_store, inner_path, visited_saveables | ||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -877,10 +901,7 @@ def _save_container_state( | |||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| for saveable in container: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(saveable, KerasSaveable): | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Do NOT address the saveable via `saveable.name`, since | ||||||||||||||||||||||||||||||||||||||||||||||||
| # names are usually autogenerated and thus not reproducible | ||||||||||||||||||||||||||||||||||||||||||||||||
| # (i.e. they may vary across two instances of the same model). | ||||||||||||||||||||||||||||||||||||||||||||||||
| name = naming.to_snake_case(saveable.__class__.__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||
| name = _get_container_item_name(saveable) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if name in used_names: | ||||||||||||||||||||||||||||||||||||||||||||||||
| used_names[name] += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| name = f"{name}_{used_names[name]}" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -908,22 +929,43 @@ def _load_container_state( | |||||||||||||||||||||||||||||||||||||||||||||||
| from keras.src.saving.keras_saveable import KerasSaveable | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| used_names = {} | ||||||||||||||||||||||||||||||||||||||||||||||||
| used_class_names = {} | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(container, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||
| container = list(container.values()) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| for saveable in container: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(saveable, KerasSaveable): | ||||||||||||||||||||||||||||||||||||||||||||||||
| name = naming.to_snake_case(saveable.__class__.__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||
| name = _get_container_item_name(saveable) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if name in used_names: | ||||||||||||||||||||||||||||||||||||||||||||||||
| used_names[name] += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| name = f"{name}_{used_names[name]}" | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| used_names[name] = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||
| candidate_path = file_utils.join(inner_path, name).replace( | ||||||||||||||||||||||||||||||||||||||||||||||||
| "\\", "/" | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Backward compatibility: if the name-based path doesn't | ||||||||||||||||||||||||||||||||||||||||||||||||
| # exist in the store, fall back to class-name-based path | ||||||||||||||||||||||||||||||||||||||||||||||||
| # (the format used before this fix). | ||||||||||||||||||||||||||||||||||||||||||||||||
| if not _store_has_path(weights_store, candidate_path): | ||||||||||||||||||||||||||||||||||||||||||||||||
| class_name = naming.to_snake_case(saveable.__class__.__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if class_name in used_class_names: | ||||||||||||||||||||||||||||||||||||||||||||||||
| used_class_names[class_name] += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| class_name = f"{class_name}_{used_class_names[class_name]}" | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| used_class_names[class_name] = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||
| fallback_path = file_utils.join(inner_path, class_name).replace( | ||||||||||||||||||||||||||||||||||||||||||||||||
| "\\", "/" | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if _store_has_path(weights_store, fallback_path): | ||||||||||||||||||||||||||||||||||||||||||||||||
| candidate_path = fallback_path | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| _load_state( | ||||||||||||||||||||||||||||||||||||||||||||||||
| saveable, | ||||||||||||||||||||||||||||||||||||||||||||||||
| weights_store, | ||||||||||||||||||||||||||||||||||||||||||||||||
| assets_store, | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_path=file_utils.join(inner_path, name).replace("\\", "/"), | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_path=candidate_path, | ||||||||||||||||||||||||||||||||||||||||||||||||
| skip_mismatch=skip_mismatch, | ||||||||||||||||||||||||||||||||||||||||||||||||
| visited_saveables=visited_saveables, | ||||||||||||||||||||||||||||||||||||||||||||||||
| failed_saveables=failed_saveables, | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1145,6 +1145,44 @@ def is_remote_path(path): | |
| model.save_weights(temp_filepath) | ||
| model.load_weights(temp_filepath) | ||
|
|
||
| def test_custom_subclass_weight_loading(self): | ||
| """Test that weights saved from a custom subclass can be loaded | ||
| into a model using the base class, and vice versa, when the layer | ||
| names match (issue #20322).""" | ||
|
|
||
| @keras.saving.register_keras_serializable(package="custom_lstm_test") | ||
| class CustomLSTM(keras.layers.LSTM): | ||
| pass | ||
|
|
||
| # Build model with custom LSTM subclass | ||
| inputs_a = keras.Input(shape=(10, 1)) | ||
| lstm_a = CustomLSTM(32, name="my_lstm") | ||
| dense_a = keras.layers.Dense(1, name="output") | ||
| model_a = keras.Model( | ||
| inputs_a, dense_a(lstm_a(inputs_a)), name="model_a" | ||
| ) | ||
|
|
||
| # Build model with vanilla LSTM (same layer name) | ||
| inputs_b = keras.Input(shape=(10, 1)) | ||
| lstm_b = keras.layers.LSTM(32, name="my_lstm") | ||
| dense_b = keras.layers.Dense(1, name="output") | ||
| model_b = keras.Model( | ||
| inputs_b, dense_b(lstm_b(inputs_b)), name="model_b" | ||
| ) | ||
|
|
||
| # Save weights from custom subclass model | ||
| temp_filepath = os.path.join( | ||
| self.get_temp_dir(), "custom_lstm.weights.h5" | ||
| ) | ||
| model_a.save_weights(temp_filepath) | ||
|
|
||
| # Load into vanilla LSTM model | ||
| model_b.load_weights(temp_filepath) | ||
|
|
||
| # Verify predictions match | ||
| x = np.random.random((1, 10, 1)).astype("float32") | ||
| self.assertAllClose(model_a(x), model_b(x)) | ||
|
Comment on lines
+1182
to
+1184
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test docstring mentions testing loading 'and vice versa', but the test currently only covers saving from a custom subclass and loading into the base class. To make the test more comprehensive and match its description, consider adding a test for the other direction: saving from the base # Verify predictions match
x = np.random.random((1, 10, 1)).astype("float32")
self.assertAllClose(model_a(x), model_b(x))
# Test the other direction: save from vanilla, load into subclass
temp_filepath_2 = os.path.join(
self.get_temp_dir(), "vanilla_lstm.weights.h5"
)
# Re-build model_a to reset its weights before loading
inputs_a_2 = keras.Input(shape=(10, 1))
lstm_a_2 = CustomLSTM(32, name="my_lstm")
dense_a_2 = keras.layers.Dense(1, name="output")
model_a_2 = keras.Model(
inputs_a_2, dense_a_2(lstm_a_2(inputs_a_2)), name="model_a"
)
model_b.save_weights(temp_filepath_2)
model_a_2.load_weights(temp_filepath_2)
# Verify predictions match
self.assertAllClose(model_a_2(x), model_b(x)) |
||
|
|
||
|
|
||
| class SavingH5IOStoreTest(testing.TestCase): | ||
| def test_h5_io_store_basics(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_get_container_item_namefunction returnssaveable.namewithout any sanitization. This name is then used to construct file paths for saving and loading model assets viaDiskIOStore. Sincesaveable.namecan be controlled by a malicious model configuration file (config.json), an attacker can use path traversal sequences (e.g.,../) or absolute paths to read or write files outside the intended temporary directory. For example, a layer named../../../../etc/passwdwould causeDiskIOStore.getto return/etc/passwd, which is then passed to the layer'sload_assetsmethod, potentially leading to an arbitrary file read.