Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,44 @@ def __repr__(self):
def __str__(self):
return self.__repr__()

# Attribute names reserved for internal tracking by the Layer class.
# Overriding these in subclasses will silently break weight
# saving/loading and other functionality.
_RESERVED_LAYER_ATTRIBUTES = frozenset(
{
"_layers",
"_metrics",
"_trainable_variables",
"_non_trainable_variables",
"_seed_generators",
}
)

# Internal modules that legitimately reassign reserved attributes.
_RESERVED_ATTR_EXEMPT_MODULES = frozenset(
{
"keras.src.models.sequential",
"keras.src.models.functional",
}
)

def __setattr__(self, name, value):
# Warn if user code reassigns a reserved tracked attribute.
if (
name in self._RESERVED_LAYER_ATTRIBUTES
and hasattr(self, "_tracker")
and hasattr(self, name)
and tracking.is_tracking_enabled()
and type(self).__module__ not in self._RESERVED_ATTR_EXEMPT_MODULES
):
warnings.warn(
f"`{name}` is a reserved attribute in Keras layers and "
"should not be used as a variable name in a Layer "
"subclass. Assigning to it can break weight saving, "
"metric tracking, and other functionality. "
f"Please use a different attribute name.",
stacklevel=2,
)
# Track Variables, Layers, Metrics, SeedGenerators.
name, value = self._setattr_hook(name, value)
if name != "_tracker":
Expand Down
21 changes: 21 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,27 @@ def call(self, x):
):
layer(np.random.random((3, 2)))

def test_reserved_attribute_warning(self):
"""Warn when user code overrides reserved tracked attributes."""

class BadLayer(layers.Layer):
def __init__(self):
super().__init__()
self._layers = [layers.Dense(4)]

with self.assertWarnsRegex(
UserWarning,
"`_layers` is a reserved attribute",
):
BadLayer()

# Internal Keras classes like Sequential should NOT warn.
import warnings

with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
models.Sequential([layers.Dense(4)])

def test_init_after_state_tracking(self):
class MyLayer(layers.Layer):
def __init__(self):
Expand Down
Loading