Skip to content

Commit edbf8f5

Browse files
authored
Update keras3 Softmax mask handling to be more numerically robust. (#21850)
* Update keras3 Softmax mask handling to be more numerically robust. * Fix formatting
1 parent 032528b commit edbf8f5

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

keras/src/layers/activations/softmax.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,15 @@ def __init__(self, axis=-1, **kwargs):
5252

5353
def call(self, inputs, mask=None):
5454
if mask is not None:
55-
adder = (
56-
1.0 - backend.cast(mask, inputs.dtype)
57-
) * _large_negative_number(inputs.dtype)
58-
inputs += adder
55+
# We keep the positions where the mask is True or > 0.5, and set the
56+
# other (masked) positions to -1e.9.
57+
if backend.standardize_dtype(mask.dtype) != "bool":
58+
mask = backend.numpy.greater(
59+
mask, backend.cast(0.5, dtype=mask.dtype)
60+
)
61+
inputs = backend.numpy.where(
62+
mask, inputs, _large_negative_number(inputs.dtype)
63+
)
5964
if isinstance(self.axis, (tuple, list)):
6065
if len(self.axis) > 1:
6166
outputs = backend.numpy.exp(

0 commit comments

Comments
 (0)