We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f9b8b1b commit 5b79b66Copy full SHA for 5b79b66
examples/losses.py
@@ -53,6 +53,7 @@ def sigmoid_cross_entropy(
53
register_loss: bool = True,
54
extra_registration_kwargs: dict[str, Any] | None = None,
55
registration_module: types.ModuleType = kfac_jax,
56
+ mask: Array | None = None,
57
) -> Array:
58
"""Sigmoid cross-entropy loss."""
59
extra_registration_kwargs = extra_registration_kwargs or {}
@@ -70,7 +71,12 @@ def sigmoid_cross_entropy(
70
71
72
log_1p = jnp.log1p(jnp.exp(neg_abs_logits))
73
- return weight * jnp.add(relu_logits - logits * labels, log_1p)
74
+ if mask is None:
75
+ mask = 1.0
76
+ else:
77
+ assert mask.shape == labels.shape
78
+
79
+ return weight * mask * jnp.add(relu_logits - logits * labels, log_1p)
80
81
82
def softmax_cross_entropy(
0 commit comments