Skip to content

Commit c02708a

Browse files
Jake VanderPlasDistraxDev
authored andcommitted
Migrate from jax.core to jax.extend.core for several deprecated symbols
A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core. PiperOrigin-RevId: 706180277
1 parent ddf4c7e commit c02708a

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

distrax/_src/utils/transformations.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
from absl import logging
5353
import jax
54+
import jax.extend as jex
5455
import jax.numpy as jnp
5556

5657
# pylint: disable=g-import-not-at-top
@@ -156,7 +157,7 @@ def is_constant_jacobian(fn, x=0.0):
156157
jac_jaxpr = jax.make_jaxpr(jac_fn)(jnp.array(x)).jaxpr
157158
dependent_vars = _dependent_variables(jac_jaxpr)
158159

159-
jac_is_constant = not any(isinstance(v, jax.core.Var) and v in dependent_vars
160+
jac_is_constant = not any(isinstance(v, jex.core.Var) and v in dependent_vars
160161
for v in jac_jaxpr.outvars)
161162

162163
return jac_is_constant
@@ -202,7 +203,7 @@ def _dependent_variables(jaxpr, dependent=None):
202203
if v in subjaxpr_dependent)
203204
else:
204205
for v in eqn.invars:
205-
if isinstance(v, jax.core.Var) and v in dependent:
206+
if isinstance(v, jex.core.Var) and v in dependent:
206207
dependent.update(eqn.outvars)
207208

208209
return dependent
@@ -226,20 +227,20 @@ def _identify_variable_in_eqn(eqn):
226227
var_idx = 0
227228

228229
elif len(eqn.invars) == 2: # binary operation
229-
if tuple(map(type, eqn.invars)) == (jax.core.Var, jax.core.Literal):
230+
if tuple(map(type, eqn.invars)) == (jex.core.Var, jex.core.Literal):
230231
var_idx = 0
231232

232-
elif tuple(map(type, eqn.invars)) == (jax.core.Literal, jax.core.Var):
233+
elif tuple(map(type, eqn.invars)) == (jex.core.Literal, jex.core.Var):
233234
var_idx = 1
234235

235-
elif tuple(map(type, eqn.invars)) == (jax.core.Var, jax.core.Var):
236+
elif tuple(map(type, eqn.invars)) == (jex.core.Var, jex.core.Var):
236237
raise NotImplementedError(
237238
"Expressions with multiple occurrences of the input variable are "
238239
"not supported. Please rearrange such that the variable appears only "
239240
"once in the expression if possible. If not possible, consider "
240241
"providing both `forward` and `inverse` to Lambda explicitly.")
241242

242-
elif tuple(map(type, eqn.invars)) == (jax.core.Literal, jax.core.Literal):
243+
elif tuple(map(type, eqn.invars)) == (jex.core.Literal, jex.core.Literal):
243244
raise ValueError("Expression appears to contain no variables and "
244245
"therefore cannot be inverted.")
245246

@@ -259,7 +260,7 @@ def _interpret_inverse(jaxpr, consts, *args):
259260
env = {}
260261

261262
def read(var):
262-
return var.val if isinstance(var, jax.core.Literal) else env[var]
263+
return var.val if isinstance(var, jex.core.Literal) else env[var]
263264
def write(var, val):
264265
env[var] = val
265266

0 commit comments

Comments
 (0)