Skip to content

Commit 107217e

Browse files
tomhennigancopybara-github
authored andcommitted
Remove two warnings that were showing when using jmp.
Fixes #39. PiperOrigin-RevId: 505639478
1 parent c8aca61 commit 107217e

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

jmp/_src/loss_scale.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class DynamicLossScale:
105105
106106
Typical usage of this class will be something like:
107107
108-
>>> loss_scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
108+
>>> loss_scale = jmp.DynamicLossScale(jnp.asarray(2. ** 15))
109109
>>> for _ in range(num_steps):
110110
... # compute loss
111111
... loss = loss_scale.scale(loss)
@@ -121,7 +121,7 @@ class DynamicLossScale:
121121
period: int = 2000
122122
factor: int = 2
123123
min_loss_scale: jnp.ndarray = dataclasses.field(
124-
default_factory=lambda: np.ones([], np.int32))
124+
default_factory=lambda: np.ones([], np.float32))
125125

126126
def __post_init__(self) -> None:
127127
warn_if_not_floating(self.loss_scale, "loss_scale")
@@ -183,7 +183,7 @@ def adjust(self, grads_finite: jnp.ndarray) -> "DynamicLossScale":
183183

184184
def all_finite(tree) -> jnp.ndarray:
185185
"""Returns a scalar ndarray indicating whether the input arrays are finite."""
186-
leaves = jax.tree_leaves(tree)
186+
leaves = jax.tree_util.tree_leaves(tree)
187187
if not leaves:
188188
return jnp.array(True)
189189
else:
@@ -195,7 +195,7 @@ def all_finite(tree) -> jnp.ndarray:
195195
def select_tree(pred: jnp.ndarray, a: T, b: T) -> T:
196196
"""Selects a pytree based on the given predicate."""
197197
assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"
198-
return jax.tree_map(functools.partial(jax.lax.select, pred), a, b)
198+
return jax.tree_util.tree_map(functools.partial(jax.lax.select, pred), a, b)
199199

200200

201201
def warn_if_not_floating(x: Union[jnp.ndarray, object], var_name: str) -> None:

jmp/_src/loss_scale_test.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Tests for jmp._src.loss_scale."""
1616

17+
import warnings
18+
1719
from absl.testing import absltest
1820
from absl.testing import parameterized
1921
import jax
@@ -51,15 +53,20 @@ def test_static_loss_scale(self, cls, scale):
5153
)
5254
def test_static_empty_trees(self, create):
5355
loss_scale = create()
54-
self.assertEmpty(jax.tree_leaves(loss_scale))
56+
self.assertEmpty(jax.tree_util.tree_leaves(loss_scale))
57+
58+
def test_dynamic_loss_scale_no_warnings(self):
59+
with warnings.catch_warnings(record=True) as logged_warnings:
60+
jmp.DynamicLossScale(2. ** 15)
61+
self.assertEmpty(logged_warnings)
5562

5663
def test_dynamic_loss_scale_tree(self):
5764
scale = jnp.ones([])
5865
counter = jnp.zeros([], jnp.int32)
5966
period = 2000
6067
factor = 2
6168
loss_scale = jmp.DynamicLossScale(scale, counter, period, factor)
62-
self.assertEqual(jax.tree_leaves(loss_scale), [scale, counter])
69+
self.assertEqual(jax.tree_util.tree_leaves(loss_scale), [scale, counter])
6370
self.assertEqual(jax.tree_util.tree_map(lambda x: x, loss_scale),
6471
loss_scale)
6572

jmp/_src/policy_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def skip_if_unsupported(dtype):
5454
class PolicyTest(parameterized.TestCase):
5555

5656
def assert_dtypes_equal(self, tree_a, tree_b):
57-
jax.tree_map(lambda a, b: self.assertEqual(a.dtype, b.dtype), tree_a,
58-
tree_b)
57+
jax.tree_util.tree_map(lambda a, b: self.assertEqual(a.dtype, b.dtype),
58+
tree_a, tree_b)
5959

6060
@parameterized.parameters(*it.product(DTYPES, NUMPYS))
6161
def test_policy_cast_to_param(self, dtype, np_):

0 commit comments

Comments
 (0)