
Commit b3b588d

Merge pull request #25 from nlsfnr:main
PiperOrigin-RevId: 503185031
2 parents: 6fa1a7d + b20cec7


2 files changed: +45 −4 lines


jmp/_src/loss_scale.py

Lines changed: 30 additions & 0 deletions
@@ -17,6 +17,7 @@
 import dataclasses
 import functools
 from typing import Tuple, TypeVar, Union
+import warnings

 import jax
 from jax import tree_util
@@ -122,6 +123,10 @@ class DynamicLossScale:
   min_loss_scale: jnp.ndarray = dataclasses.field(
       default_factory=lambda: np.ones([], np.int32))

+  def __post_init__(self) -> None:
+    warn_if_not_floating(self.loss_scale, "loss_scale")
+    warn_if_not_floating(self.min_loss_scale, "min_loss_scale")
+
   def scale(self, tree: T) -> T:
     # usage_logging.log_event(usage_logging.Event.JMP, "DynamicLossScale")
     return jax.tree_util.tree_map(lambda x: x * self.loss_scale, tree)
@@ -191,3 +196,28 @@ def select_tree(pred: jnp.ndarray, a: T, b: T) -> T:
   """Selects a pytree based on the given predicate."""
   assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"
   return jax.tree_map(functools.partial(jax.lax.select, pred), a, b)
+
+
+def warn_if_not_floating(x: Union[jnp.ndarray, object], var_name: str) -> None:
+  """Produces a warning if the given array does not have a floating type.
+
+  This function handles an edgecase where Jax passes in an `object()` to
+  determine the structure of user defined pytrees during compilation. They
+  recommend explicitly checking if the array in question has the type `object`.
+
+  From the Jax documentation: "The __init__ and __new__ methods of custom
+  PyTree classes should generally avoid doing any array conversion or other
+  input validation, or else anticipate and handle these special cases."
+
+  See:
+  https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
+
+  Args:
+    x: Any object.
+    var_name: A useful name to put in error messages.
+  """
+  if type(x) is object:  # pylint: disable=unidiomatic-typecheck
+    return
+  x_dtype = jax.eval_shape(lambda: x).dtype
+  if not jnp.issubdtype(x_dtype, jnp.floating):
+    warnings.warn(f"Expected floating type for {var_name}, got {x_dtype}")

jmp/_src/loss_scale_test.py

Lines changed: 15 additions & 4 deletions
@@ -34,9 +34,9 @@ def test_no_op_loss_scale(self):
       ("StaticLossScale(2)", jmp.StaticLossScale, 2),
       ("StaticLossScale(3)", jmp.StaticLossScale, 3),
       ("StaticLossScale(4)", jmp.StaticLossScale, 4),
-      ("DynamicLossScale(2)", jmp.DynamicLossScale, 2),
-      ("DynamicLossScale(3)", jmp.DynamicLossScale, 3),
-      ("DynamicLossScale(4)", jmp.DynamicLossScale, 4),
+      ("DynamicLossScale(2)", jmp.DynamicLossScale, 2.),
+      ("DynamicLossScale(3)", jmp.DynamicLossScale, 3.),
+      ("DynamicLossScale(4)", jmp.DynamicLossScale, 4.),
   )
   def test_static_loss_scale(self, cls, scale):
     loss_scale = cls(scale)
@@ -98,7 +98,7 @@ def test_dynamic_loss_scale_adjust_reduce_on_non_finite(self, period, factor):
     self.assertEqual(loss_scale.period, period)
     self.assertEqual(loss_scale.factor, factor)

-  @parameterized.parameters((20, 2, .3125), (30, 3, .37), (5, 2, 0))
+  @parameterized.parameters((20, 2, .3125), (30, 3, .37), (5., 2., 0.))
   def test_dynamic_loss_scale_explicit_min_loss_scale(self, period, factor,
                                                       min_loss_scale):
     grads_finite = jnp.bool_(False)
@@ -120,6 +120,17 @@ def test_dynamic_loss_scale_explicit_min_loss_scale(self, period, factor,
   def test_dynamic_loss_scale_adjust_requires_scalar_input(self):
     pass

+  def test_dynamic_loss_scale_raises_type_error_on_int_loss_scale(self):
+    expected_message = "Expected floating type for loss_scale"
+    with self.assertWarnsRegex(Warning, expected_message):
+      jmp.DynamicLossScale(jnp.asarray(1, dtype=jnp.int32))
+
+  def test_dynamic_loss_scale_raises_type_error_on_int_min_loss_scale(self):
+    expected_message = "Expected floating type for min_loss_scale"
+    with self.assertWarnsRegex(Warning, expected_message):
+      jmp.DynamicLossScale(jnp.asarray(1, dtype=jnp.float32),
+                           min_loss_scale=jnp.asarray(1, dtype=jnp.int32))
+
   @parameterized.parameters(jnp.inf, jnp.nan)
   def test_all_finite(self, non_finite):
     self.assertTrue(jmp.all_finite(None))
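
The switch from integer to float parameters in these tests keeps them clear of the new warning. For context, a short sketch of the dynamic-scaling loop those values feed into (illustrative only; the 2 ** 15 starting scale and the fake gradients are not from the commit, and DynamicLossScale.adjust / jmp.all_finite are existing jmp APIs referenced by these tests rather than changed here):

import jax.numpy as jnp
import jmp

loss_scale = jmp.DynamicLossScale(jnp.float32(2 ** 15),
                                  min_loss_scale=jnp.float32(1e-4))

grads = {"w": jnp.array([jnp.inf, 1.0])}      # pretend these grads overflowed
grads_finite = jmp.all_finite(grads)          # scalar bool, False here
loss_scale = loss_scale.adjust(grads_finite)  # non-finite grads shrink the scale,
                                              # bounded below by min_loss_scale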
