Skip to content

Commit 977e53a

Browse files
james-martensKfacJaxDev
authored andcommitted
Fixing a PyType issue caused by recent JAX CL.
PiperOrigin-RevId: 828956396
1 parent 566ffb1 commit 977e53a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

kfac_jax/_src/tag_graph_matcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def make_jax_graph(
420420
closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr)
421421
closed_jaxpr = clean_jaxpr(closed_jaxpr)
422422

423-
in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars)
423+
in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars) # pytype:disable=attribute-error
424424

425425
if isinstance(params_index, int):
426426
params_vars = in_vars[params_index]

0 commit comments

Comments
 (0)