@@ -105,7 +105,7 @@ class DynamicLossScale:
105105
106106 Typical usage of this class will be something like:
107107
108- >>> loss_scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
108+ >>> loss_scale = jmp.DynamicLossScale(jnp.asarray(2. ** 15))
109109 >>> for _ in range(num_steps):
110110 ... # compute loss
111111 ... loss = loss_scale.scale(loss)
@@ -121,7 +121,7 @@ class DynamicLossScale:
121121 period : int = 2000
122122 factor : int = 2
123123 min_loss_scale : jnp .ndarray = dataclasses .field (
124- default_factory = lambda : np .ones ([], np .int32 ))
124+ default_factory = lambda : np .ones ([], np .float32 ))
125125
126126 def __post_init__ (self ) -> None :
127127 warn_if_not_floating (self .loss_scale , "loss_scale" )
@@ -183,7 +183,7 @@ def adjust(self, grads_finite: jnp.ndarray) -> "DynamicLossScale":
183183
184184def all_finite (tree ) -> jnp .ndarray :
185185 """Returns a scalar ndarray indicating whether the input arrays are finite."""
186- leaves = jax .tree_leaves (tree )
186+ leaves = jax .tree_util . tree_leaves (tree )
187187 if not leaves :
188188 return jnp .array (True )
189189 else :
@@ -195,7 +195,7 @@ def all_finite(tree) -> jnp.ndarray:
195195def select_tree (pred : jnp .ndarray , a : T , b : T ) -> T :
196196 """Selects a pytree based on the given predicate."""
197197 assert pred .ndim == 0 and pred .dtype == jnp .bool_ , "expected boolean scalar"
198- return jax .tree_map (functools .partial (jax .lax .select , pred ), a , b )
198+ return jax .tree_util . tree_map (functools .partial (jax .lax .select , pred ), a , b )
199199
200200
201201def warn_if_not_floating (x : Union [jnp .ndarray , object ], var_name : str ) -> None :
0 commit comments