diff --git a/keras/src/saving/file_editor.py b/keras/src/saving/file_editor.py
index b486590f2132..850a3a664fc5 100644
--- a/keras/src/saving/file_editor.py
+++ b/keras/src/saving/file_editor.py
@@ -455,6 +455,9 @@ def resave_weights(self, filepath):
 
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}
+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
@@ -462,26 +465,110 @@ def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
             metadata[inner_path] = object_metadata
 
         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = f"{inner_path}/{key}"
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue
 
-            if hasattr(value, "keys"):
+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"], metadata=metadata, inner_path=inner_path
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value, metadata=metadata, inner_path=inner_path
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-            else:
-                result[key] = value[()]
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                continue
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+            try:
+                # No negative dims
+                if any(dim < 0 for dim in shape):
+                    raise ValueError(
+                        "Negative dimension in HDF5 dataset shape."
+                    )
+
+                # Prevent absurdly high-rank tensors
+                if len(shape) > 64:
+                    raise ValueError("HDF5 dataset rank too large (>64).")
+
+                # Ensure product does not overflow
+                num_elems = int(np.prod(shape))
+                if num_elems < 0:
+                    raise ValueError(
+                        "Overflow in dataset shape multiplication."
+                    )
+
+            except Exception as e:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "refusing to load."
+                ) from e
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 30  # 1 GiB
+
+            try:
+                size_bytes = num_elems * dtype.itemsize
+            except Exception as e:
+                raise ValueError(
+                    "Malformed HDF5 dtype encountered in .keras file; "
+                    "refusing to load."
+                ) from e
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 1 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata
 
     def _generate_filepath_info(self, rich_style=False):
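Not part of the patch: a minimal standalone sketch of the size guard that the second hunk introduces, useful for sanity-checking the 1 GiB limit in isolation. The helper name check_dataset_size and the example shapes are hypothetical; MAX_BYTES mirrors the constant added in the diff, and only numpy is assumed.

    import numpy as np

    MAX_BYTES = 1 << 30  # 1 GiB, same limit as in the patch


    def check_dataset_size(shape, dtype):
        """Raise ValueError if a dataset of this shape/dtype exceeds MAX_BYTES."""
        if any(dim < 0 for dim in shape):
            raise ValueError("Negative dimension in dataset shape.")
        if len(shape) > 64:
            raise ValueError("Dataset rank too large (>64).")
        num_elems = int(np.prod(shape, dtype=np.int64))
        size_bytes = num_elems * np.dtype(dtype).itemsize
        if size_bytes > MAX_BYTES:
            raise ValueError(
                f"Dataset too large to load safely "
                f"({size_bytes} bytes; limit is {MAX_BYTES})."
            )
        return size_bytes


    check_dataset_size((1024, 1024, 128), "float32")    # ~512 MiB, passes
    # check_dataset_size((1024, 1024, 512), "float32")  # ~2 GiB, raises ValueError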