Skip to content

Commit 18c2cc1

Browse files
Fix xla_pmap_p import for JAX versions that removed pmap (#2173)
JAX removed the C++ pmap infrastructure (including xla_pmap_p) in a recent release. Guard the import so numpyro works with both old and new JAX versions. Co-authored-by: Meesum Qazalbash <meesumqazalbash@gmail.com>
1 parent bf4e8ef commit 18c2cc1

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

numpyro/ops/provenance.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
import jax
55
from jax.api_util import debug_info, flatten_fun, shaped_abstractify
66
from jax.extend.core import Literal
7-
from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p
7+
from jax.extend.core.primitives import call_p, closed_call_p, jit_p
8+
9+
try:
10+
from jax.extend.core.primitives import xla_pmap_p
11+
except ImportError:
12+
xla_pmap_p = None
813
import jax.extend.linear_util as lu
914
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
1015

@@ -114,7 +119,8 @@ def track_deps_call_rule(eqn, provenance_inputs):
114119

115120

116121
track_deps_rules[call_p] = track_deps_call_rule
117-
track_deps_rules[xla_pmap_p] = track_deps_call_rule
122+
if xla_pmap_p is not None:
123+
track_deps_rules[xla_pmap_p] = track_deps_call_rule
118124

119125

120126
def track_deps_closed_call_rule(eqn, provenance_inputs):

0 commit comments

Comments
 (0)