We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9aaa92b commit 1ded22fCopy full SHA for 1ded22f
examples/losses.py
@@ -53,7 +53,6 @@ def sigmoid_cross_entropy(
53
register_loss: bool = True,
54
extra_registration_kwargs: dict[str, Any] | None = None,
55
registration_module: types.ModuleType = kfac_jax,
56
- mask: Array | None = None,
57
) -> Array:
58
"""Sigmoid cross-entropy loss."""
59
extra_registration_kwargs = extra_registration_kwargs or {}
@@ -71,12 +70,7 @@ def sigmoid_cross_entropy(
71
70
72
log_1p = jnp.log1p(jnp.exp(neg_abs_logits))
73
74
- if mask is None:
75
- mask = 1.0
76
- else:
77
- assert mask.shape == labels.shape
78
-
79
- return weight * mask * jnp.add(relu_logits - logits * labels, log_1p)
+ return weight * jnp.add(relu_logits - logits * labels, log_1p)
80
81
82
def softmax_cross_entropy(
0 commit comments