diff --git a/examples/losses.py b/examples/losses.py index 6785578..da1b301 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -53,6 +53,7 @@ def sigmoid_cross_entropy( register_loss: bool = True, extra_registration_kwargs: dict[str, Any] | None = None, registration_module: types.ModuleType = kfac_jax, + mask: Array | None = None, ) -> Array: """Sigmoid cross-entropy loss.""" extra_registration_kwargs = extra_registration_kwargs or {} @@ -70,7 +71,12 @@ def sigmoid_cross_entropy( log_1p = jnp.log1p(jnp.exp(neg_abs_logits)) - return weight * jnp.add(relu_logits - logits * labels, log_1p) + if mask is None: + mask = 1.0 + else: + assert mask.shape == labels.shape + + return weight * mask * jnp.add(relu_logits - logits * labels, log_1p) def softmax_cross_entropy(