Skip to content

Commit b5bd0af

Browse files
timothyn617KfacJaxDev
authored andcommitted
Updated type annotations for 'weight' (in examples code loss functions) to allow JAX arrays, and removed runtime type check from softmax_cross_entropy function.
PiperOrigin-RevId: 845212033
1 parent 406b21f commit b5bd0af

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

examples/losses.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def l2_regularizer(
4949
def sigmoid_cross_entropy(
5050
logits: Array,
5151
labels: Array,
52-
weight: float = 1.0,
52+
weight: Numeric = 1.0,
5353
register_loss: bool = True,
5454
extra_registration_kwargs: dict[str, Any] | None = None,
5555
registration_module: types.ModuleType = kfac_jax,
@@ -94,10 +94,6 @@ def softmax_cross_entropy(
9494

9595
if register_loss:
9696

97-
if not isinstance(weight, float):
98-
raise NotImplementedError("Non-constant loss weights are not currently "
99-
"supported.")
100-
10197
registration_module.register_softmax_cross_entropy_loss(
10298
logits,
10399
targets=labels,
@@ -147,7 +143,7 @@ def softmax_cross_entropy(
147143
def squared_error(
148144
prediction: Array,
149145
targets: Array,
150-
weight: float = 1.0,
146+
weight: Numeric = 1.0,
151147
register_loss: bool = True,
152148
mask: Array | None = None,
153149
extra_registration_kwargs: dict[str, Any] | None = None,

0 commit comments

Comments
 (0)