Commit 18c2cc1
Fix xla_pmap_p import for JAX versions that removed pmap (#2173)
JAX removed the C++ pmap infrastructure (including xla_pmap_p) in a
recent release. Guard the import so numpyro works with both old and
new JAX versions.
Co-authored-by: Meesum Qazalbash <meesumqazalbash@gmail.com>1 parent bf4e8ef commit 18c2cc1
1 file changed
Lines changed: 8 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
7 | | - | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
8 | 13 | | |
9 | 14 | | |
10 | 15 | | |
| |||
114 | 119 | | |
115 | 120 | | |
116 | 121 | | |
117 | | - | |
| 122 | + | |
| 123 | + | |
118 | 124 | | |
119 | 125 | | |
120 | 126 | | |
| |||
0 commit comments