Skip to content

Commit 175cace

Browse files
suryabhupaDistraxDev
authored andcommitted
Replace deprecated jax.linear_util.wrap_init with jax.extend.linear_util.wrap_init.
PiperOrigin-RevId: 562017973
1 parent 93c54a8 commit 175cace

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

distrax/_src/utils/transformations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def write(var, val):
258258
# if primitive is an xla_call, get subexpressions and evaluate recursively
259259
call_jaxpr, params = _extract_call_jaxpr(eqn.primitive, params)
260260
if call_jaxpr:
261-
subfuns = [jax.linear_util.wrap_init(
261+
subfuns = [jax.extend.linear_util.wrap_init(
262262
functools.partial(_interpret_inverse, call_jaxpr, ()))]
263263
prim_inv = eqn.primitive
264264

0 commit comments

Comments
 (0)