|
4 | 4 | import jax |
5 | 5 | from jax.api_util import flatten_fun, shaped_abstractify |
6 | 6 | from jax.experimental.pjit import pjit_p |
7 | | -import jax.util as util |
8 | 7 |
|
9 | 8 | try: |
10 | 9 | import jax.extend.linear_util as lu |
|
30 | 29 | from jax.interpreters.pxla import xla_pmap_p |
31 | 30 |
|
32 | 31 |
|
| 32 | +# Adapted from definition in jax v0.5.0 |
| 33 | +def _safe_map(f, *args): |
| 34 | + args = list(map(list, args)) |
| 35 | + n = len(args[0]) |
| 36 | + for arg in args[1:]: |
| 37 | + assert len(arg) == n, f'length mismatch: {list(map(len, args))}' |
| 38 | + return list(map(f, *args)) |
| 39 | + |
| 40 | + |
33 | 41 | def eval_provenance(fn, **kwargs): |
34 | 42 | """ |
35 | 43 | Compute the provenance output of ``fun`` using JAX's abstract |
@@ -60,7 +68,7 @@ def eval_provenance(fn, **kwargs): |
60 | 68 | ) |
61 | 69 | wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn, **fn_info), in_tree) |
62 | 70 | # Abstract eval to get output pytree |
63 | | - avals = util.safe_map(shaped_abstractify, args) |
| 71 | + avals = _safe_map(shaped_abstractify, args) |
64 | 72 | # XXX: we split out the process of abstract evaluation and provenance tracking |
65 | 73 | # for simplicity. In principle, they can be merged so that we only need to walk |
66 | 74 | # through the equations once. |
@@ -102,14 +110,14 @@ def write(v, p): |
102 | 110 | return |
103 | 111 | env[v] = read(v) | p |
104 | 112 |
|
105 | | - util.safe_map(write, jaxpr.invars, provenance_inputs) |
| 113 | + _safe_map(write, jaxpr.invars, provenance_inputs) |
106 | 114 | for eqn in jaxpr.eqns: |
107 | | - provenance_inputs = util.safe_map(read, eqn.invars) |
| 115 | + provenance_inputs = _safe_map(read, eqn.invars) |
108 | 116 | rule = track_deps_rules.get(eqn.primitive, _default_track_deps_rules) |
109 | 117 | provenance_outputs = rule(eqn, provenance_inputs) |
110 | | - util.safe_map(write, eqn.outvars, provenance_outputs) |
| 118 | + _safe_map(write, eqn.outvars, provenance_outputs) |
111 | 119 |
|
112 | | - return util.safe_map(read, jaxpr.outvars) |
| 120 | + return _safe_map(read, jaxpr.outvars) |
113 | 121 |
|
114 | 122 |
|
115 | 123 | track_deps_rules = {} |
|
0 commit comments