Skip to content

Commit fb183b4

Browse files
committed
Avoid importing deprecated jax.util module
1 parent 448ea3f commit fb183b4

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

numpyro/ops/provenance.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import jax
55
from jax.api_util import flatten_fun, shaped_abstractify
66
from jax.experimental.pjit import pjit_p
7-
import jax.util as util
87

98
try:
109
import jax.extend.linear_util as lu
@@ -30,6 +29,15 @@
3029
from jax.interpreters.pxla import xla_pmap_p
3130

3231

32+
# Adapted from definition in jax v0.5.0
33+
def _safe_map(f, *args):
34+
args = list(map(list, args))
35+
n = len(args[0])
36+
for arg in args[1:]:
37+
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
38+
return list(map(f, *args))
39+
40+
3341
def eval_provenance(fn, **kwargs):
3442
"""
3543
Compute the provenance output of ``fun`` using JAX's abstract
@@ -60,7 +68,7 @@ def eval_provenance(fn, **kwargs):
6068
)
6169
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn, **fn_info), in_tree)
6270
# Abstract eval to get output pytree
63-
avals = util.safe_map(shaped_abstractify, args)
71+
avals = _safe_map(shaped_abstractify, args)
6472
# XXX: we split out the process of abstract evaluation and provenance tracking
6573
# for simplicity. In principle, they can be merged so that we only need to walk
6674
# through the equations once.
@@ -102,14 +110,14 @@ def write(v, p):
102110
return
103111
env[v] = read(v) | p
104112

105-
util.safe_map(write, jaxpr.invars, provenance_inputs)
113+
_safe_map(write, jaxpr.invars, provenance_inputs)
106114
for eqn in jaxpr.eqns:
107-
provenance_inputs = util.safe_map(read, eqn.invars)
115+
provenance_inputs = _safe_map(read, eqn.invars)
108116
rule = track_deps_rules.get(eqn.primitive, _default_track_deps_rules)
109117
provenance_outputs = rule(eqn, provenance_inputs)
110-
util.safe_map(write, eqn.outvars, provenance_outputs)
118+
_safe_map(write, eqn.outvars, provenance_outputs)
111119

112-
return util.safe_map(read, jaxpr.outvars)
120+
return _safe_map(read, jaxpr.outvars)
113121

114122

115123
track_deps_rules = {}

0 commit comments

Comments
 (0)