5151
5252from absl import logging
5353import jax
54+ import jax .extend as jex
5455import jax .numpy as jnp
5556
5657# pylint: disable=g-import-not-at-top
@@ -156,7 +157,7 @@ def is_constant_jacobian(fn, x=0.0):
156157 jac_jaxpr = jax .make_jaxpr (jac_fn )(jnp .array (x )).jaxpr
157158 dependent_vars = _dependent_variables (jac_jaxpr )
158159
159- jac_is_constant = not any (isinstance (v , jax .core .Var ) and v in dependent_vars
160+ jac_is_constant = not any (isinstance (v , jex .core .Var ) and v in dependent_vars
160161 for v in jac_jaxpr .outvars )
161162
162163 return jac_is_constant
@@ -202,7 +203,7 @@ def _dependent_variables(jaxpr, dependent=None):
202203 if v in subjaxpr_dependent )
203204 else :
204205 for v in eqn .invars :
205- if isinstance (v , jax .core .Var ) and v in dependent :
206+ if isinstance (v , jex .core .Var ) and v in dependent :
206207 dependent .update (eqn .outvars )
207208
208209 return dependent
@@ -226,20 +227,20 @@ def _identify_variable_in_eqn(eqn):
226227 var_idx = 0
227228
228229 elif len (eqn .invars ) == 2 : # binary operation
229- if tuple (map (type , eqn .invars )) == (jax .core .Var , jax .core .Literal ):
230+ if tuple (map (type , eqn .invars )) == (jex .core .Var , jex .core .Literal ):
230231 var_idx = 0
231232
232- elif tuple (map (type , eqn .invars )) == (jax .core .Literal , jax .core .Var ):
233+ elif tuple (map (type , eqn .invars )) == (jex .core .Literal , jex .core .Var ):
233234 var_idx = 1
234235
235- elif tuple (map (type , eqn .invars )) == (jax .core .Var , jax .core .Var ):
236+ elif tuple (map (type , eqn .invars )) == (jex .core .Var , jex .core .Var ):
236237 raise NotImplementedError (
237238 "Expressions with multiple occurrences of the input variable are "
238239 "not supported. Please rearrange such that the variable appears only "
239240 "once in the expression if possible. If not possible, consider "
240241 "providing both `forward` and `inverse` to Lambda explicitly." )
241242
242- elif tuple (map (type , eqn .invars )) == (jax .core .Literal , jax .core .Literal ):
243+ elif tuple (map (type , eqn .invars )) == (jex .core .Literal , jex .core .Literal ):
243244 raise ValueError ("Expression appears to contain no variables and "
244245 "therefore cannot be inverted." )
245246
@@ -259,7 +260,7 @@ def _interpret_inverse(jaxpr, consts, *args):
259260 env = {}
260261
261262 def read (var ):
262- return var .val if isinstance (var , jax .core .Literal ) else env [var ]
263+ return var .val if isinstance (var , jex .core .Literal ) else env [var ]
263264 def write (var , val ):
264265 env [var ] = val
265266
0 commit comments