Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion keras/src/layers/regularization/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,45 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
)
self.rate = rate
self.seed = seed
self.noise_shape = noise_shape
self.noise_shape = self._validate_noise_shape(noise_shape)
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True

self._build_at_init()

def _validate_noise_shape(self, noise_shape):
if noise_shape is None:
return None

if not isinstance(noise_shape, tuple):
try:
noise_shape = tuple(noise_shape)
except TypeError:
raise ValueError(
f"Invalid value received for argument `noise_shape`. "
f"Expected a tuple or list of integers. "
f"Received: noise_shape={noise_shape}"
)
Comment on lines +66 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message here states that noise_shape is expected to be a "tuple or list of integers". However, the code correctly attempts to convert any iterable to a tuple by calling tuple(noise_shape). To avoid potential confusion for users who might pass another valid iterable type (like a set or a generator), it would be more accurate to mention "iterable" in the error message. This change would make the feedback more precise, which is encouraged by the Keras API design guidelines.1

Suggested change
raise ValueError(
f"Invalid value received for argument `noise_shape`. "
f"Expected a tuple or list of integers. "
f"Received: noise_shape={noise_shape}"
)
raise ValueError(
f"Invalid value received for argument `noise_shape`. "
f"Expected an iterable of integers (e.g., a tuple or list). "
f"Received: noise_shape={noise_shape}"
)

Style Guide References

Footnotes

  1. Error messages should be contextual, informative, and actionable. A good error message should clearly and precisely state what was expected to help the user fix the issue.


for i, dim in enumerate(noise_shape):
if dim is not None:
if not isinstance(dim, int):
raise ValueError(
f"Invalid value received for argument `noise_shape`. "
f"Expected all elements to be integers or None. "
f"Received element at index {i}: {dim} (type: {type(dim).__name__})"
)

if dim <= 0:
raise ValueError(
f"Invalid value received for argument `noise_shape`. "
f"Expected all dimensions to be positive integers or None. "
f"Received negative or zero value at index {i}: {dim}"
)

return noise_shape

def call(self, inputs, training=False):
if training and self.rate > 0:
return backend.random.dropout(
Expand Down
Loading