We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 566ffb1 commit 977e53aCopy full SHA for 977e53a
kfac_jax/_src/tag_graph_matcher.py
@@ -420,7 +420,7 @@ def make_jax_graph(
420
closed_jaxpr = merge_broadcasts_jaxpr(closed_jaxpr)
421
closed_jaxpr = clean_jaxpr(closed_jaxpr)
422
423
- in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars)
+ in_vars = jax.tree_util.tree_unflatten(in_tree, closed_jaxpr.jaxpr.invars) # pytype:disable=attribute-error
424
425
if isinstance(params_index, int):
426
params_vars = in_vars[params_index]
0 commit comments