Skip to content

Commit 1ded22f

Browse files
james-martensKfacJaxDev
authored andcommitted
Removing broken "mask" feature from sigmoid_cross_entropy loss in examples code.
PiperOrigin-RevId: 845765776
1 parent 9aaa92b commit 1ded22f

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

examples/losses.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def sigmoid_cross_entropy(
5353
register_loss: bool = True,
5454
extra_registration_kwargs: dict[str, Any] | None = None,
5555
registration_module: types.ModuleType = kfac_jax,
56-
mask: Array | None = None,
5756
) -> Array:
5857
"""Sigmoid cross-entropy loss."""
5958
extra_registration_kwargs = extra_registration_kwargs or {}
@@ -71,12 +70,7 @@ def sigmoid_cross_entropy(
7170

7271
log_1p = jnp.log1p(jnp.exp(neg_abs_logits))
7372

74-
if mask is None:
75-
mask = 1.0
76-
else:
77-
assert mask.shape == labels.shape
78-
79-
return weight * mask * jnp.add(relu_logits - logits * labels, log_1p)
73+
return weight * jnp.add(relu_logits - logits * labels, log_1p)
8074

8175

8276
def softmax_cross_entropy(

0 commit comments

Comments
 (0)