We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 93c54a8 commit 175caceCopy full SHA for 175cace
distrax/_src/utils/transformations.py
@@ -258,7 +258,7 @@ def write(var, val):
258
# if primitive is an xla_call, get subexpressions and evaluate recursively
259
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, params)
260
if call_jaxpr:
261
- subfuns = [jax.linear_util.wrap_init(
+ subfuns = [jax.extend.linear_util.wrap_init(
262
functools.partial(_interpret_inverse, call_jaxpr, ()))]
263
prim_inv = eqn.primitive
264
0 commit comments