Skip to content

Commit 5b79b66

Browse files
timothyn617KfacJaxDev
authored andcommitted
Some refactoring of methods to enable additional customization of BaseTrainer, including how input data to models are sharded.
PiperOrigin-RevId: 766568928
1 parent f9b8b1b commit 5b79b66

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

examples/losses.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def sigmoid_cross_entropy(
5353
register_loss: bool = True,
5454
extra_registration_kwargs: dict[str, Any] | None = None,
5555
registration_module: types.ModuleType = kfac_jax,
56+
mask: Array | None = None,
5657
) -> Array:
5758
"""Sigmoid cross-entropy loss."""
5859
extra_registration_kwargs = extra_registration_kwargs or {}
@@ -70,7 +71,12 @@ def sigmoid_cross_entropy(
7071

7172
log_1p = jnp.log1p(jnp.exp(neg_abs_logits))
7273

73-
return weight * jnp.add(relu_logits - logits * labels, log_1p)
74+
if mask is None:
75+
mask = 1.0
76+
else:
77+
assert mask.shape == labels.shape
78+
79+
return weight * mask * jnp.add(relu_logits - logits * labels, log_1p)
7480

7581

7682
def softmax_cross_entropy(

0 commit comments

Comments
 (0)